feat: Pipeline Parallel support (#581)

* wip

Co-authored-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

* support all models

* fix: ray

---------

Co-authored-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
AlpinDale, 7 months ago
Parent
Current commit
ae04f57ec1
65 changed files with 821 additions and 349 deletions
  1. aphrodite/common/config.py (+21 -3)
  2. aphrodite/common/sequence.py (+31 -0)
  3. aphrodite/distributed/parallel_state.py (+28 -22)
  4. aphrodite/distributed/utils.py (+10 -1)
  5. aphrodite/engine/aphrodite_engine.py (+50 -16)
  6. aphrodite/engine/async_aphrodite.py (+63 -16)
  7. aphrodite/engine/output_processor/interfaces.py (+1 -1)
  8. aphrodite/engine/output_processor/multi_step.py (+3 -2)
  9. aphrodite/engine/output_processor/single_step.py (+13 -7)
  10. aphrodite/executor/distributed_gpu_executor.py (+6 -6)
  11. aphrodite/executor/executor_base.py (+25 -0)
  12. aphrodite/executor/gpu_executor.py (+2 -1)
  13. aphrodite/executor/multiproc_gpu_executor.py (+6 -6)
  14. aphrodite/executor/ray_gpu_executor.py (+71 -18)
  15. aphrodite/modeling/models/arctic.py (+2 -1)
  16. aphrodite/modeling/models/baichuan.py (+2 -1)
  17. aphrodite/modeling/models/bloom.py (+2 -1)
  18. aphrodite/modeling/models/chatglm.py (+2 -1)
  19. aphrodite/modeling/models/commandr.py (+2 -1)
  20. aphrodite/modeling/models/dbrx.py (+2 -1)
  21. aphrodite/modeling/models/deepseek.py (+2 -1)
  22. aphrodite/modeling/models/deepseek_v2.py (+2 -1)
  23. aphrodite/modeling/models/falcon.py (+2 -1)
  24. aphrodite/modeling/models/gemma.py (+3 -2)
  25. aphrodite/modeling/models/gpt2.py (+2 -1)
  26. aphrodite/modeling/models/gpt_bigcode.py (+2 -1)
  27. aphrodite/modeling/models/gpt_j.py (+2 -1)
  28. aphrodite/modeling/models/gpt_neox.py (+2 -1)
  29. aphrodite/modeling/models/internlm2.py (+2 -1)
  30. aphrodite/modeling/models/jais.py (+2 -1)
  31. aphrodite/modeling/models/llama.py (+72 -28)
  32. aphrodite/modeling/models/llava.py (+4 -2)
  33. aphrodite/modeling/models/llava_next.py (+3 -1)
  34. aphrodite/modeling/models/minicpm.py (+2 -1)
  35. aphrodite/modeling/models/mixtral.py (+2 -1)
  36. aphrodite/modeling/models/mixtral_quant.py (+2 -1)
  37. aphrodite/modeling/models/mpt.py (+2 -1)
  38. aphrodite/modeling/models/olmo.py (+2 -1)
  39. aphrodite/modeling/models/opt.py (+2 -1)
  40. aphrodite/modeling/models/orion.py (+2 -1)
  41. aphrodite/modeling/models/phi.py (+2 -1)
  42. aphrodite/modeling/models/phi3_small.py (+2 -1)
  43. aphrodite/modeling/models/phi3v.py (+8 -3)
  44. aphrodite/modeling/models/qwen.py (+2 -1)
  45. aphrodite/modeling/models/qwen2.py (+2 -1)
  46. aphrodite/modeling/models/qwen2_moe.py (+2 -1)
  47. aphrodite/modeling/models/stablelm.py (+2 -1)
  48. aphrodite/modeling/models/starcoder2.py (+2 -1)
  49. aphrodite/modeling/models/xverse.py (+2 -1)
  50. aphrodite/processing/block_manager_v1.py (+3 -0)
  51. aphrodite/processing/block_manager_v2.py (+3 -0)
  52. aphrodite/processing/scheduler.py (+11 -2)
  53. aphrodite/spec_decode/draft_model_runner.py (+10 -5)
  54. aphrodite/task_handler/cache_engine.py (+4 -0)
  55. aphrodite/task_handler/cpu_model_runner.py (+4 -1)
  56. aphrodite/task_handler/cpu_worker.py (+24 -14)
  57. aphrodite/task_handler/embedding_model_runner.py (+8 -3)
  58. aphrodite/task_handler/model_runner.py (+201 -125)
  59. aphrodite/task_handler/model_runner_base.py (+4 -1)
  60. aphrodite/task_handler/neuron_model_runner.py (+4 -1)
  61. aphrodite/task_handler/neuron_worker.py (+1 -1)
  62. aphrodite/task_handler/worker.py (+27 -18)
  63. aphrodite/task_handler/worker_base.py (+31 -9)
  64. aphrodite/task_handler/xpu_model_runner.py (+4 -2)
  65. aphrodite/task_handler/xpu_worker.py (+2 -2)

+ 21 - 3
aphrodite/common/config.py

@@ -29,6 +29,16 @@ APHRODITE_USE_MODELSCOPE = os.environ.get("APHRODITE_USE_MODELSCOPE",
 _GB = 1 << 30
 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
 
+_PP_SUPPORTED_MODELS = [
+    "AquilaModel",
+    "AquilaForCausalLM",
+    "InternLMForCausalLM",
+    "LlamaForCausalLM",
+    "LLaMAForCausalLM",
+    "MistralForCausalLM",
+    "Phi3ForCausalLM",
+]
+
 
 class ModelConfig:
     """Configuration for the model.
@@ -338,6 +348,13 @@ class ModelConfig:
         total_num_hidden_layers = getattr(self.hf_text_config,
                                           "num_hidden_layers", 0)
         pipeline_parallel_size = parallel_config.pipeline_parallel_size
+        architectures = getattr(self.hf_config, "architectures", [])
+        if not all(arch in _PP_SUPPORTED_MODELS
+                   for arch in architectures) and pipeline_parallel_size > 1:
+            raise NotImplementedError(
+                "Pipeline parallelism is only supported for the following "
+                f" architectures: {_PP_SUPPORTED_MODELS}.")
+
         if total_num_hidden_layers % pipeline_parallel_size != 0:
             raise ValueError(
                 f"Total number of hidden layers ({total_num_hidden_layers}) "
@@ -747,9 +764,10 @@ class ParallelConfig:
         self._verify_args()
 
     def _verify_args(self) -> None:
-        if self.pipeline_parallel_size > 1:
-            raise NotImplementedError(
-                "Pipeline parallelism is not supported yet.")
+        if (self.pipeline_parallel_size > 1
+                and self.distributed_executor_backend == "mp"):
+            raise NotImplementedError("Pipeline parallelism is not supported "
+                                      "yet with multiprocessing.")
         if self.distributed_executor_backend not in ("ray", "mp", None):
             raise ValueError(
                 "Unrecognized distributed executor backend. Supported values "

+ 31 - 0
aphrodite/common/sequence.py

@@ -770,6 +770,34 @@ class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
         return self.embeddings == other.embeddings
 
 
+@dataclass
+class IntermediateTensors:
+    """For all pipeline stages except the last, we need to return the hidden
+    states and residuals to be sent to the next stage. This data structure
+    contains the hidden states and residuals for a request.
+    """
+
+    tensors: Dict[str, torch.Tensor]
+
+    def __getitem__(self, key: Union[str, slice]):
+        if isinstance(key, str):
+            return self.tensors[key]
+        elif isinstance(key, slice):
+            return self.__class__({k: v[key] for k, v in self.tensors.items()})
+
+    def __setitem__(self, key: str, value):
+        self.tensors[key] = value
+
+    def __len__(self):
+        return len(self.tensors)
+
+    def __eq__(self, other: object):
+        return (isinstance(other, self.__class__)
+                and self.tensors.keys() == other.tensors.keys()
+                and all(torch.equal(v, other.tensors[k])
+                        for k, v in self.tensors.items()))
+
+    def __repr__(self) -> str:
+        return f"IntermediateTensors(tensors={self.tensors})"
+
+
 @dataclass
 class SamplerOutput:
     """For each sequence group, we generate a list of SequenceOutput object,
@@ -896,6 +924,8 @@ class ExecuteModelRequest:
     blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
     # Blocks to copy. Source to dest block.
     blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
+    # Virtual engine ID for pipeline parallel.
+    virtual_engine: int = 0
     # The number of slots for lookahead decoding.
     num_lookahead_slots: int = 0
     # The number of requests in the running queue.
@@ -914,6 +944,7 @@ class ExecuteModelRequest:
             blocks_to_swap_in=self.blocks_to_swap_in.copy(),
             blocks_to_swap_out=self.blocks_to_swap_out.copy(),
             blocks_to_copy=self.blocks_to_copy.copy(),
+            virtual_engine=self.virtual_engine,
             num_lookahead_slots=self.num_lookahead_slots,
             running_queue_size=self.running_queue_size,
             previous_hidden_states=self.previous_hidden_states,
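
A short usage sketch of the new IntermediateTensors container (the key names here are illustrative; slicing an instance slices every contained tensor along its first dimension, which is how a batch is split):

    import torch
    from aphrodite.common.sequence import IntermediateTensors

    it = IntermediateTensors(tensors={
        "hidden_states": torch.zeros(8, 4096),
        "residual": torch.zeros(8, 4096),
    })
    first_half = it[:4]          # IntermediateTensors over rows 0..3
    hs = it["hidden_states"]     # plain tensor lookup by key
    it["residual"] = torch.ones(8, 4096)
    assert len(it) == 2          # number of named tensors, not rows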

+ 28 - 22
aphrodite/distributed/parallel_state.py

@@ -415,7 +415,7 @@ class GroupCoordinator:
 
         assert dst < self.world_size, f"Invalid dst rank ({dst})"
 
-        assert dst != self.rank, (
+        assert dst != self.rank_in_group, (
             "Invalid destination rank. Destination rank is the same "
             "as the current rank.")
 
@@ -445,7 +445,7 @@ class GroupCoordinator:
 
         assert src < self.world_size, f"Invalid src rank ({src})"
 
-        assert src != self.rank, (
+        assert src != self.rank_in_group, (
             "Invalid source rank. Source rank is the same as the current rank."
         )
 
@@ -453,7 +453,7 @@ class GroupCoordinator:
 
         # Receive object size
         rank_size = torch.distributed.recv(size_tensor,
-                                           src=src,
+                                           src=self.ranks[src],
                                            group=self.cpu_group)
 
         # Tensor to receive serialized objects into.
@@ -463,7 +463,7 @@ class GroupCoordinator:
             device="cpu")
 
         rank_object = torch.distributed.recv(object_tensor,
-                                             src=src,
+                                             src=self.ranks[src],
                                              group=self.cpu_group)
 
         assert rank_object == rank_size, (
@@ -490,10 +490,9 @@ class GroupCoordinator:
         group = self.device_group
         metadata_group = self.cpu_group
         assert src < self.world_size, f"Invalid src rank ({src})"
-        src = self.ranks[src]
 
-        rank = self.rank
-        if rank == src:
+        rank_in_group = self.rank_in_group
+        if rank_in_group == src:
             metadata_list: List[Tuple[Any, Any]] = []
             assert isinstance(
                 tensor_dict,
@@ -511,13 +510,13 @@ class GroupCoordinator:
                 if tensor.is_cpu:
                     # use metadata_group for CPU tensors
                     handle = torch.distributed.broadcast(tensor,
-                                                         src=src,
+                                                         src=self.ranks[src],
                                                          group=metadata_group,
                                                          async_op=True)
                 else:
                     # use group for GPU tensors
                     handle = torch.distributed.broadcast(tensor,
-                                                         src=src,
+                                                         src=self.ranks[src],
                                                          group=group,
                                                          async_op=True)
                 async_handles.append(handle)
@@ -541,15 +540,16 @@ class GroupCoordinator:
                         # use metadata_group for CPU tensors
                         handle = torch.distributed.broadcast(
                             tensor,
-                            src=src,
+                            src=self.ranks[src],
                             group=metadata_group,
                             async_op=True)
                     else:
                         # use group for GPU tensors
-                        handle = torch.distributed.broadcast(tensor,
-                                                             src=src,
-                                                             group=group,
-                                                             async_op=True)
+                        handle = torch.distributed.broadcast(
+                            tensor,
+                            src=self.ranks[src],
+                            group=group,
+                            async_op=True)
                     async_handles.append(handle)
                     _update_nested_dict(tensor_dict, key, tensor)
                 else:
@@ -574,7 +574,7 @@ class GroupCoordinator:
         metadata_group = self.cpu_group
 
         if dst is None:
-            dst = self.next_rank
+            dst = (self.rank_in_group + 1) % self.world_size
         assert dst < self.world_size, f"Invalid dst rank ({dst})"
 
         metadata_list: List[Tuple[Any, Any]] = []
@@ -592,10 +592,14 @@ class GroupCoordinator:
                 continue
             if tensor.is_cpu:
                 # use metadata_group for CPU tensors
-                torch.distributed.send(tensor, dst=dst, group=metadata_group)
+                torch.distributed.send(tensor,
+                                       dst=self.ranks[dst],
+                                       group=metadata_group)
             else:
                 # use group for GPU tensors
-                torch.distributed.send(tensor, dst=dst, group=group)
+                torch.distributed.send(tensor,
+                                       dst=self.ranks[dst],
+                                       group=group)
         return None
 
     def recv_tensor_dict(
@@ -613,7 +617,7 @@ class GroupCoordinator:
         metadata_group = self.cpu_group
 
         if src is None:
-            src = self.prev_rank
+            src = (self.rank_in_group - 1) % self.world_size
         assert src < self.world_size, f"Invalid src rank ({src})"
 
         recv_metadata_list = self.recv_object(src=src)
@@ -630,11 +634,13 @@ class GroupCoordinator:
                 if tensor.is_cpu:
                     # use metadata_group for CPU tensors
                     torch.distributed.recv(tensor,
-                                           src=src,
+                                           src=self.ranks[src],
                                            group=metadata_group)
                 else:
                     # use group for GPU tensors
-                    torch.distributed.recv(tensor, src=src, group=group)
+                    torch.distributed.recv(tensor,
+                                           src=self.ranks[src],
+                                           group=group)
                 _update_nested_dict(tensor_dict, key, tensor)
             else:
                 _update_nested_dict(tensor_dict, key, value)
@@ -653,7 +659,7 @@ class GroupCoordinator:
         """Sends a tensor to the destination rank in a non-blocking way"""
         """NOTE: `dst` is the local rank of the destination rank."""
         if dst is None:
-            dst = self.next_rank
+            dst = (self.rank_in_group + 1) % self.world_size
 
         pynccl_comm = self.pynccl_comm
         if pynccl_comm is not None and not pynccl_comm.disabled:
@@ -668,7 +674,7 @@ class GroupCoordinator:
         """Receives a tensor from the src rank."""
         """NOTE: `src` is the local rank of the destination rank."""
         if src is None:
-            src = self.prev_rank
+            src = (self.rank_in_group - 1) % self.world_size
 
         tensor = torch.empty(size, dtype=dtype, device=self.device)
         pynccl_comm = self.pynccl_comm
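
The recurring fix in this file is keeping group-local and global ranks apart: rank_in_group and the src/dst arguments are positions inside the group, while torch.distributed APIs expect global ranks, so every call now translates through self.ranks. A standalone sketch of that mapping (hypothetical values, not the real GroupCoordinator):

    # A PP group built from global ranks [2, 3] (e.g. TP rank 1 of a
    # TP=2 x PP=2 layout). Within the group these are local ranks 0 and 1.
    ranks = [2, 3]                 # local rank -> global rank
    rank_in_group = 0              # this process is group member 0
    dst = (rank_in_group + 1) % len(ranks)  # local rank of the next stage
    global_dst = ranks[dst]        # what torch.distributed.send() expects
    assert (dst, global_dst) == (1, 3)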

+ 10 - 1
aphrodite/distributed/utils.py

@@ -3,7 +3,7 @@
 # Adapted from
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-from typing import Sequence
+from typing import Sequence, Tuple
 
 import torch
 
@@ -47,3 +47,12 @@ def split_tensor_along_last_dim(
         return tuple(chunk.contiguous() for chunk in tensor_list)
 
     return tensor_list
+
+
+def get_pp_indices(num_hidden_layers: int, pp_rank: int,
+                   pp_size: int) -> Tuple[int, int]:
+    layers_per_partition = divide(num_hidden_layers, pp_size)
+    start_layer = pp_rank * layers_per_partition
+    end_layer = start_layer + layers_per_partition
+
+    return (start_layer, end_layer)
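
Worked example of the new helper (divide() asserts the layer count splits evenly across stages, which the config check added above already guarantees):

    from aphrodite.distributed.utils import get_pp_indices

    # 32 hidden layers over 4 pipeline stages -> 8 contiguous layers each.
    assert [get_pp_indices(32, rank, 4) for rank in range(4)] == [
        (0, 8), (8, 16), (16, 24), (24, 32)]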

+ 50 - 16
aphrodite/engine/aphrodite_engine.py

@@ -168,7 +168,8 @@ class AphroditeEngine:
             f"Speculative Config = {speculative_config!r}\n"
             f"DataType = {model_config.dtype}\n"
             f"Model Load Format = {load_config.load_format}\n"
-            f"Number of GPUs = {parallel_config.tensor_parallel_size}\n"
+            f"Tensor Parallel Size = {parallel_config.tensor_parallel_size}\n"
+            f"Pipeline Parallel Size = {parallel_config.pipeline_parallel_size}\n"  # noqa: E501
             f"Disable Custom All-Reduce = "
             f"{parallel_config.disable_custom_all_reduce}\n"
             f"Quantization Format = {model_config.quantization}\n"
@@ -231,7 +232,11 @@ class AphroditeEngine:
         # Create the scheduler.
        # NOTE: the cache_config here has been updated with the numbers of
         # GPU and CPU blocks, which are profiled in the distributed executor.
-        self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
+        self.scheduler = [
+            Scheduler(scheduler_config, cache_config, lora_config,
+                      parallel_config.pipeline_parallel_size)
+            for _ in range(parallel_config.pipeline_parallel_size)
+        ]
 
         # Metric Logging.
         if self.log_stats:
@@ -435,8 +440,16 @@ class AphroditeEngine:
             raise ValueError(
                 "Either SamplingParams or PoolingParams must be provided.")
 
-        # Add the sequence group to the scheduler.
-        self.scheduler.add_seq_group(seq_group)
+        # Add the sequence group to the scheduler with least unfinished seqs.
+        costs = [
+            scheduler.get_num_unfinished_seq_groups()
+            for scheduler in self.scheduler
+        ]
+        min_cost_scheduler = self.scheduler[costs.index(min(costs))]
+        min_cost_scheduler.add_seq_group(seq_group)
+
+    def stop_remote_worker_execution_loop(self) -> None:
+        self.model_executor.stop_remote_worker_execution_loop()
 
     def process_model_inputs(
         self,
@@ -601,7 +614,8 @@ class AphroditeEngine:
             >>> # abort the request
             >>> engine.abort_request(request_id)
         """
-        self.scheduler.abort_seq_group(request_id)
+        for scheduler in self.scheduler:
+            scheduler.abort_seq_group(request_id)
 
     def get_model_config(self) -> ModelConfig:
         """Gets the model configuration."""
@@ -613,11 +627,20 @@ class AphroditeEngine:
 
     def get_num_unfinished_requests(self) -> int:
         """Gets the number of unfinished requests."""
-        return self.scheduler.get_num_unfinished_seq_groups()
+        return sum(scheduler.get_num_unfinished_seq_groups()
+                   for scheduler in self.scheduler)
 
     def has_unfinished_requests(self) -> bool:
         """Returns True if there are unfinished requests."""
-        return self.scheduler.has_unfinished_seqs()
+        return any(scheduler.has_unfinished_seqs()
+                   for scheduler in self.scheduler)
+
+    def has_unfinished_requests_for_virtual_engine(
+            self, virtual_engine: int) -> bool:
+        """
+        Returns True if there are unfinished requests for the virtual engine.
+        """
+        return self.scheduler[virtual_engine].has_unfinished_seqs()
 
     def _process_sequence_group_outputs(
         self,
@@ -666,7 +689,8 @@ class AphroditeEngine:
                 self.output_processor.process_outputs(seq_group, outputs)
 
         # Free the finished sequence groups.
-        self.scheduler.free_finished_seq_groups()
+        for scheduler in self.scheduler:
+            scheduler.free_finished_seq_groups()
 
         # Create the outputs.
         request_outputs: List[Union[RequestOutput,
@@ -732,7 +756,12 @@ class AphroditeEngine:
             >>>     if not (engine.has_unfinished_requests() or example_inputs):
             >>>         break
         """
-        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
+        if self.parallel_config.pipeline_parallel_size > 1:
+            raise NotImplementedError(
+                "Pipeline parallelism is only supported through AsyncAphrodite "
+                "as performance will be severely degraded otherwise.")
+        seq_group_metadata_list, scheduler_outputs = self.scheduler[
+            0].schedule()
 
         if not scheduler_outputs.is_empty():
             execute_model_req = ExecuteModelRequest(
@@ -800,23 +829,28 @@ class AphroditeEngine:
 
         # System State
         #   Scheduler State
-        num_running_sys = len(self.scheduler.running)
-        num_swapped_sys = len(self.scheduler.swapped)
-        num_waiting_sys = len(self.scheduler.waiting)
+        num_running_sys = sum(
+            len(scheduler.running) for scheduler in self.scheduler)
+        num_swapped_sys = sum(
+            len(scheduler.swapped) for scheduler in self.scheduler)
+        num_waiting_sys = sum(
+            len(scheduler.waiting) for scheduler in self.scheduler)
 
         # KV Cache Usage in %
         num_total_gpu = self.cache_config.num_gpu_blocks
         gpu_cache_usage_sys = 0.
         if num_total_gpu is not None:
-            num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks(
-            )
+            num_free_gpu = sum(
+                scheduler.block_manager.get_num_free_gpu_blocks()
+                for scheduler in self.scheduler)
             gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
 
         num_total_cpu = self.cache_config.num_cpu_blocks
         cpu_cache_usage_sys = 0.
         if num_total_cpu is not None and num_total_cpu > 0:
-            num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
-            )
+            num_free_cpu = sum(
+                scheduler.block_manager.get_num_free_cpu_blocks()
+                for scheduler in self.scheduler)
             cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
 
         # Iteration stats
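
The engine now keeps one Scheduler per virtual engine and routes each new request to the least-loaded one, mirroring the argmin added above. A reduced sketch of that placement policy (TinyScheduler is a stand-in for the real Scheduler):

    class TinyScheduler:
        def __init__(self):
            self.seq_groups = []

        def get_num_unfinished_seq_groups(self):
            return len(self.seq_groups)

        def add_seq_group(self, seq_group):
            self.seq_groups.append(seq_group)

    pipeline_parallel_size = 2
    schedulers = [TinyScheduler() for _ in range(pipeline_parallel_size)]

    def add_request(seq_group):
        # Pick the virtual engine with the fewest unfinished groups.
        costs = [s.get_num_unfinished_seq_groups() for s in schedulers]
        schedulers[costs.index(min(costs))].add_seq_group(seq_group)

    for i in range(5):
        add_request(f"req-{i}")
    assert [len(s.seq_groups) for s in schedulers] == [3, 2]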

+ 63 - 16
aphrodite/engine/async_aphrodite.py

@@ -209,7 +209,8 @@ class _AsyncAphrodite(AphroditeEngine):
     """Extension of AphroditeEngine to add async methods."""
 
     async def step_async(
-            self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
+        self, virtual_engine: int
+    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.
        The workers are run asynchronously if possible.
 
@@ -219,7 +220,8 @@ class _AsyncAphrodite(AphroditeEngine):
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
+        seq_group_metadata_list, scheduler_outputs = self.scheduler[
+            virtual_engine].schedule()
 
         if not scheduler_outputs.is_empty():
             # Execute the model.
@@ -228,6 +230,7 @@ class _AsyncAphrodite(AphroditeEngine):
                 blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                 blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                 blocks_to_copy=scheduler_outputs.blocks_to_copy,
+                virtual_engine=virtual_engine,
                 num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                 running_queue_size=scheduler_outputs.running_queue_size,
             )
@@ -243,16 +246,12 @@ class _AsyncAphrodite(AphroditeEngine):
         # Log stats.
         self.do_log_stats(scheduler_outputs, output)
 
-        if not request_outputs:
-            # Stop the execute model loop in parallel workers until there are
-            # more requests to process. This avoids waiting indefinitely in
-            # torch.distributed ops which may otherwise timeout, and unblocks
-            # the RPC thread in the workers so that they can process any other
-            # queued control plane messages, such as add/remove lora adapters.
-            await self.model_executor.stop_remote_worker_execution_loop_async()
-
         return request_outputs
 
+    async def stop_remote_worker_execution_loop_async(self) -> None:
+        """Stop the remote worker execution loop."""
+        await self.model_executor.stop_remote_worker_execution_loop_async()
+
     async def process_model_inputs_async(
         self,
         request_id: str,
@@ -487,7 +486,8 @@ class AsyncAphrodite:
             # order of the arguments.
             cache_config = kwargs["cache_config"]
             parallel_config = kwargs["parallel_config"]
-            if parallel_config.tensor_parallel_size == 1:
+            if (parallel_config.tensor_parallel_size == 1
+                    and parallel_config.pipeline_parallel_size == 1):
                 num_gpus = cache_config.gpu_memory_utilization
             else:
                 num_gpus = 1
@@ -495,7 +495,7 @@ class AsyncAphrodite:
                 self._engine_class).remote
         return engine_class(*args, **kwargs)
 
-    async def engine_step(self) -> bool:
+    async def engine_step(self, virtual_engine: int) -> bool:
         """Kick the engine to process the waiting requests.
 
         Returns True if there are in-progress requests."""
@@ -526,7 +526,7 @@ class AsyncAphrodite:
         if self.engine_use_ray:
             request_outputs = await self.engine.step.remote()  # type: ignore
         else:
-            request_outputs = await self.engine.step_async()
+            request_outputs = await self.engine.step_async(virtual_engine)
 
         # Put the outputs into the corresponding streams.
         for request_output in request_outputs:
@@ -542,18 +542,65 @@ class AsyncAphrodite:
             self.engine.abort_request(request_ids)
 
     async def run_engine_loop(self):
-        has_requests_in_progress = False
+        if self.engine_use_ray:
+            pipeline_parallel_size = 1  # type: ignore
+        else:
+            pipeline_parallel_size = \
+                self.engine.parallel_config.pipeline_parallel_size
+        has_requests_in_progress = [False] * pipeline_parallel_size
         while True:
-            if not has_requests_in_progress:
+            if not any(has_requests_in_progress):
                 logger.debug("Waiting for new requests...")
+                # Stop the execute model loop in parallel workers until there
+                # are more requests to process. This avoids waiting
+                # indefinitely in torch.distributed ops which may otherwise
+                # timeout, and unblocks the RPC thread in the workers so that
+                # they can process any other queued control plane messages,
+                # such as add/remove lora adapters.
+                if self.engine_use_ray:
+                    await (self.engine.stop_remote_worker_execution_loop.
+                           remote()  # type: ignore
+                           )
+                else:
+                    await self.engine.stop_remote_worker_execution_loop_async()
                 await self._request_tracker.wait_for_new_requests()
                 logger.debug("Got new requests!")
+                requests_in_progress = [
+                    asyncio.create_task(self.engine_step(ve))
+                    for ve in range(pipeline_parallel_size)
+                ]
+                has_requests_in_progress = [True] * pipeline_parallel_size
 
             # Abort if iteration takes too long due to unrecoverable errors
             # (eg. NCCL timeouts).
             try:
                 async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
-                    has_requests_in_progress = await self.engine_step()
+                    done, _ = await asyncio.wait(
+                        requests_in_progress,
+                        return_when=asyncio.FIRST_COMPLETED)
+                    for _ in range(pipeline_parallel_size):
+                        await asyncio.sleep(0)
+                for task in done:
+                    result = task.result()
+                    virtual_engine = requests_in_progress.index(task)
+                    if self.engine_use_ray:
+                        has_unfinished_requests = (
+                            await (self.engine.
+                                   has_unfinished_requests_for_virtual_engine.
+                                   remote(  # type: ignore
+                                       virtual_engine)))
+                    else:
+                        has_unfinished_requests = (
+                            self.engine.
+                            has_unfinished_requests_for_virtual_engine(
+                                virtual_engine))
+                    if result or has_unfinished_requests:
+                        requests_in_progress[virtual_engine] = (
+                            asyncio.create_task(
+                                self.engine_step(virtual_engine)))
+                        has_requests_in_progress[virtual_engine] = True
+                    else:
+                        has_requests_in_progress[virtual_engine] = False
             except asyncio.TimeoutError as exc:
                 logger.error(
                     "Engine iteration timed out. This should never happen!")

+ 1 - 1
aphrodite/engine/output_processor/interfaces.py

@@ -28,7 +28,7 @@ class SequenceGroupOutputProcessor(ABC):
     def create_output_processor(
         scheduler_config: SchedulerConfig,
         detokenizer: Detokenizer,
-        scheduler: Scheduler,
+        scheduler: List[Scheduler],
         seq_counter: Counter,
         get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
         stop_checker: "StopChecker",

+ 3 - 2
aphrodite/engine/output_processor/multi_step.py

@@ -33,7 +33,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
     def __init__(
         self,
         detokenizer: Detokenizer,
-        scheduler: Scheduler,
+        scheduler: List[Scheduler],
         seq_counter: Counter,
         get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
         stop_checker: StopChecker,
@@ -139,4 +139,5 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                 break
 
         if seq.is_finished():
-            self.scheduler.free_seq(seq)
+            for scheduler in self.scheduler:
+                scheduler.free_seq(seq)

+ 13 - 7
aphrodite/engine/output_processor/single_step.py

@@ -31,7 +31,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
         self,
         scheduler_config: SchedulerConfig,
         detokenizer: Detokenizer,
-        scheduler: Scheduler,
+        scheduler: List[Scheduler],
         seq_counter: Counter,
         stop_checker: StopChecker,
     ):
@@ -93,7 +93,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
                 # not be used in the future iterations.
                 parent.status = SequenceStatus.FINISHED_ABORTED
                 seq_group.remove(parent.seq_id)
-                self.scheduler.free_seq(parent)
+                for scheduler in self.scheduler:
+                    scheduler.free_seq(parent)
                 continue
             # Fork the parent sequence if there are multiple child samples.
             for child_sample in child_samples[:-1]:
@@ -131,7 +132,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
                 if seq is not parent:
                     seq_group.add(seq)
                     if not seq.is_finished():
-                        self.scheduler.fork_seq(parent, seq)
+                        for scheduler in self.scheduler:
+                            scheduler.fork_seq(parent, seq)
 
             # Free the finished and selected parent sequences' memory in block
             # manager. Keep them in the sequence group as candidate output.
@@ -139,7 +141,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
             # old sequences.
             for seq, parent in child_seqs:
                 if seq is parent and seq.is_finished():
-                    self.scheduler.free_seq(seq)
+                    for scheduler in self.scheduler:
+                        scheduler.free_seq(seq)
             return
 
         # Beam search case
@@ -224,13 +227,15 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
             if seq is not parent:
                 seq_group.add(seq)
                 if not seq.is_finished():
-                    self.scheduler.fork_seq(parent, seq)
+                    for scheduler in self.scheduler:
+                        scheduler.fork_seq(parent, seq)
 
         # Free the finished and selected parent sequences' memory in block
         # manager. Keep them in the sequence group as candidate output.
         for seq, parent in selected_child_seqs:
             if seq is parent and seq.is_finished():
-                self.scheduler.free_seq(seq)
+                for scheduler in self.scheduler:
+                    scheduler.free_seq(seq)
 
         # Remove the unselected parent sequences from the sequence group and
         # free their memory in block manager.
@@ -239,7 +244,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
                 # Remove the parent sequence if it is not selected for next
                 # iteration
                 seq_group.remove(seq.seq_id)
-                self.scheduler.free_seq(seq)
+                for scheduler in self.scheduler:
+                    scheduler.free_seq(seq)
 
     def _check_beam_search_early_stopping(
         self,

+ 6 - 6
aphrodite/executor/distributed_gpu_executor.py

@@ -69,7 +69,7 @@ class DistributedGPUExecutor(GPUExecutor):
         if self.parallel_worker_tasks is None:
             self.parallel_worker_tasks = self._run_workers(
                 "start_worker_execution_loop",
-                async_run_remote_workers_only=True,
+                async_run_tensor_parallel_workers_only=True,
                 **self.extra_execute_model_run_workers_kwargs)
 
         # Only the driver worker returns the sampling results.
@@ -137,16 +137,16 @@ class DistributedGPUExecutor(GPUExecutor):
         self,
         method: str,
         *args,
-        async_run_remote_workers_only: bool = False,
+        async_run_tensor_parallel_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
         **kwargs,
     ) -> Any:
         """Runs the given method on all workers.
         Args:
-            async_run_remote_workers_only: If True the method will be run only
-                in the remote workers, not the driver worker. It will also be
-                run asynchronously and return a list of futures rather than
-                blocking on the results.
+            async_run_tensor_parallel_workers_only: If True the method will be
+                run only in the remote TP workers, not the driver worker.
+                It will also be run asynchronously and return a list of futures
+                rather than blocking on the results.
         """
         raise NotImplementedError
 

+ 25 - 0
aphrodite/executor/executor_base.py

@@ -1,3 +1,4 @@
+import asyncio
 from abc import ABC, abstractmethod
 from typing import List, Optional, Set, Tuple
 
@@ -111,6 +112,30 @@ class ExecutorBase(ABC):
 
 class ExecutorAsyncBase(ExecutorBase):
 
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        cache_config: CacheConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        load_config: LoadConfig,
+        lora_config: Optional[LoRAConfig],
+        vision_language_config: Optional[VisionLanguageConfig],
+        speculative_config: Optional[SpeculativeConfig],
+    ) -> None:
+        # This locks each pipeline parallel stage so multiple virtual engines
+        # can't execute on the same stage at the same time
+        self.pp_locks = [
+            asyncio.Lock()
+            for _ in range(parallel_config.pipeline_parallel_size)
+        ]
+
+        super().__init__(model_config, cache_config, parallel_config,
+                         scheduler_config, device_config, load_config,
+                         lora_config, vision_language_config,
+                         speculative_config)
+
     @abstractmethod
     async def execute_model_async(
             self,

+ 2 - 1
aphrodite/executor/gpu_executor.py

@@ -45,7 +45,8 @@ class GPUExecutor(ExecutorBase):
             lora_config=self.lora_config,
             vision_language_config=self.vision_language_config,
             speculative_config=self.speculative_config,
-            is_driver_worker=rank == 0,
+            is_driver_worker=(not self.parallel_config)
+            or (rank % self.parallel_config.tensor_parallel_size == 0),
         )
 
     def _create_worker(self,
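
With this change the driver of each tensor-parallel group is any worker whose global rank is a multiple of tensor_parallel_size, not just global rank 0. Worked example for TP=2, PP=2 (global ranks 0..3):

    tensor_parallel_size = 2
    drivers = [rank for rank in range(4) if rank % tensor_parallel_size == 0]
    assert drivers == [0, 2]   # one TP driver per pipeline stage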

+ 6 - 6
aphrodite/executor/multiproc_gpu_executor.py

@@ -88,16 +88,16 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
         self,
         method: str,
         *args,
-        async_run_remote_workers_only: bool = False,
+        async_run_tensor_parallel_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
         **kwargs,
     ) -> Any:
         """Runs the given method on all workers.
         Args:
-            async_run_remote_workers_only: If True the method will be run only
-                in the remote workers, not the driver worker. It will also be
-                run asynchronously and return a list of futures rather than
-                blocking on the results.
+            async_run_tensor_parallel_workers_only: If True the method will be
+                run only in the remote TP workers, not the driver worker.
+                It will also be run asynchronously and return a list of futures
+                rather than blocking on the results.
         """
 
         if max_concurrent_workers:
@@ -110,7 +110,7 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
             for worker in self.workers
         ]
 
-        if async_run_remote_workers_only:
+        if async_run_tensor_parallel_workers_only:
             # Just return futures
             return worker_outputs
 

+ 71 - 18
aphrodite/executor/ray_gpu_executor.py

@@ -9,7 +9,7 @@ from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
 from aphrodite.common.utils import (get_aphrodite_instance_id,
                                     get_distributed_init_method, get_ip,
                                     get_open_port, make_async)
-from aphrodite.executor.distributed_gpu_executor import (
+from aphrodite.executor.distributed_gpu_executor import (  # yapf: disable
     DistributedGPUExecutor, DistributedGPUExecutorAsync)
 from aphrodite.executor.ray_utils import RayWorkerWrapper, ray
 
@@ -23,12 +23,12 @@ if TYPE_CHECKING:
 # which optimizes the control plane overhead.
 # Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
 USE_RAY_COMPILED_DAG = bool(os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", 0))
+APHRODITE_TRACE_FUNCTION = int(os.getenv("APHRODITE_TRACE_FUNCTION", 0))
 
 
 class RayGPUExecutor(DistributedGPUExecutor):
 
     def _init_executor(self) -> None:
-
         assert self.parallel_config.distributed_executor_backend == "ray"
         placement_group = self.parallel_config.placement_group
 
@@ -63,7 +63,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
 
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
-        if self.parallel_config.tensor_parallel_size == 1:
+        if (self.parallel_config.tensor_parallel_size == 1
+                and self.parallel_config.pipeline_parallel_size == 1):
             # For single GPU case, we use a ray worker with constrained memory.
             num_gpus = self.cache_config.gpu_memory_utilization
         else:
@@ -72,7 +73,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
 
         # The driver dummy worker does not actually use any resources.
         # It holds the resource for the driver worker.
-        self.driver_dummy_worker: RayWorkerWrapper = None
+        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
         # The remaining workers are the actual ray actors.
         self.workers: List[RayWorkerWrapper] = []
 
@@ -90,12 +91,14 @@ class RayGPUExecutor(DistributedGPUExecutor):
                 placement_group_capture_child_tasks=True,
                 placement_group_bundle_index=bundle_id,
             )
+
             if self.speculative_config is not None:
                 worker_module_name = "aphrodite.spec_decode.spec_decode_worker"
                 worker_class_name = "create_spec_worker"
             else:
                 worker_module_name = "aphrodite.task_handler.worker"
                 worker_class_name = "Worker"
+
             worker = ray.remote(
                 num_cpus=0,
                 num_gpus=num_gpus,
@@ -135,6 +138,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
         node_gpus = defaultdict(list)
 
         for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
+            node_workers[node_id].append(i)
             # `gpu_ids` can be a list of strings or integers.
             # convert them to integers for consistency.
             # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
@@ -153,7 +157,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
             "APHRODITE_INSTANCE_ID":
             APHRODITE_INSTANCE_ID,
             "APHRODITE_TRACE_FUNCTION":
-            os.getenv("APHRODITE_TRACE_FUNCTION", "0"),
+            str(APHRODITE_TRACE_FUNCTION),
         }, ) for (node_id, _) in worker_node_and_gpu_ids]
         self._run_workers("update_environment_variables",
                           all_args=all_args_to_update_environment_variables)
@@ -186,10 +190,31 @@ class RayGPUExecutor(DistributedGPUExecutor):
                           max_concurrent_workers=self.parallel_config.
                           max_parallel_loading_workers)
 
+        # This is the list of workers that are rank 0 of each TP group EXCEPT
+        # global rank 0. These are the workers that will broadcast to the
+        # rest of the workers.
+        self.tp_driver_workers: List[RayWorkerWrapper] = []
+        # This is the list of workers that are not drivers and not the first
+        # worker in a TP group. These are the workers that will be
+        # broadcasted to.
+        self.non_driver_workers: List[RayWorkerWrapper] = []
+
+        for pp_rank in range(self.parallel_config.pipeline_parallel_size):
+            for tp_rank in range(self.parallel_config.tensor_parallel_size):
+                rank = (pp_rank *
+                        self.parallel_config.tensor_parallel_size) + tp_rank
+                if rank == 0:
+                    pass
+                elif rank % self.parallel_config.tensor_parallel_size == 0:
+                    self.tp_driver_workers.append(self.workers[rank - 1])
+                else:
+                    self.non_driver_workers.append(self.workers[rank - 1])
+
     def _driver_execute_model(
         self, execute_model_req: Optional[ExecuteModelRequest]
     ) -> Optional[List[SamplerOutput]]:
         """Run execute_model in the driver worker.
+
         Passing None will cause the driver to stop the model execution
         loop running in each of the remote workers.
         """
@@ -200,7 +225,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
         self,
         method: str,
         *args,
-        async_run_remote_workers_only: bool = False,
+        async_run_tensor_parallel_workers_only: bool = False,
         all_args: Optional[List[Tuple[Any, ...]]] = None,
         all_kwargs: Optional[List[Dict[str, Any]]] = None,
         use_dummy_driver: bool = False,
@@ -211,10 +236,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
         """Runs the given method on all workers. Can be used in the following
         ways:
 
-        - async_run_remote_workers_only: If True the method will be run only
-          in the remote workers, not the driver worker. It will also be
-          run asynchronously and return a list of futures rather than blocking
-          on the results.
+        Args:
+        - async_run_tensor_parallel_workers_only: If True the method will be
+          run only in the remote TP workers, not the driver worker.
+          It will also be run asynchronously and return a list of futures
+          rather than blocking on the results.
         - args/kwargs: All workers share the same args/kwargs
         - all_args/all_kwargs: args/kwargs for each worker are specified
           individually
@@ -224,7 +250,9 @@ class RayGPUExecutor(DistributedGPUExecutor):
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")
 
-        count = len(self.workers)
+        count = len(self.workers) if not \
+            async_run_tensor_parallel_workers_only \
+            else len(self.non_driver_workers)
         all_worker_args = repeat(args, count) if all_args is None \
             else islice(all_args, 1, None)
         all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
@@ -238,14 +266,17 @@ class RayGPUExecutor(DistributedGPUExecutor):
             ray_worker_outputs = []
         else:
             # Start the ray workers first.
+            ray_workers = self.workers
+            if async_run_tensor_parallel_workers_only:
+                ray_workers = self.non_driver_workers
             ray_worker_outputs = [
                 worker.execute_method.remote(method, *worker_args,
                                              **worker_kwargs)
                 for (worker, worker_args, worker_kwargs
-                     ) in zip(self.workers, all_worker_args, all_worker_kwargs)
+                     ) in zip(ray_workers, all_worker_args, all_worker_kwargs)
             ]
 
-        if async_run_remote_workers_only:
+        if async_run_tensor_parallel_workers_only:
             # Just return futures
             return ray_worker_outputs
 
@@ -257,6 +288,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
             driver_worker_output = self.driver_worker.execute_method(
                 method, *driver_args, **driver_kwargs)
         else:
+            assert self.driver_dummy_worker is not None
             driver_worker_output = ray.get(
                 self.driver_dummy_worker.execute_method.remote(
                     method, *driver_args, **driver_kwargs))
@@ -297,8 +329,9 @@ class RayGPUExecutor(DistributedGPUExecutor):
         # a dummy value for now. It will be fixed soon.
         with InputNode() as input_data:
             forward_dag = MultiOutputNode([
-                worker.execute_model_compiled_dag_remote.bind(input_data)
-                for worker in self.workers
+                worker.execute_model_compiled_dag_remote.
+                bind(  # type: ignore[attr-defined]
+                    input_data) for worker in self.workers
             ])
         return forward_dag.experimental_compile()
 
@@ -313,12 +346,32 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None
     ) -> List[SamplerOutput]:
-        return await self.driver_exec_method("execute_model",
-                                             execute_model_req)
+
+        async def _run_task_with_lock(task, lock, *args, **kwargs):
+            async with lock:
+                return await task(*args, **kwargs)
+
+        tasks = []
+        tasks.append(
+            asyncio.create_task(
+                _run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
+                                    "execute_model", execute_model_req)))
+        for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
+                                                start=1):
+            tasks.append(
+                asyncio.create_task(
+                    _run_task_with_lock(driver_worker.execute_method.remote,
+                                        self.pp_locks[pp_rank],
+                                        "execute_model", execute_model_req)))
+
+        results = await asyncio.gather(*tasks)
+
+        # Only the last PP stage has the final results.
+        return results[-1]
 
     async def _start_worker_execution_loop(self):
         coros = [
             worker.execute_method.remote("start_worker_execution_loop")
-            for worker in self.workers
+            for worker in self.non_driver_workers
         ]
         return await asyncio.gather(*coros)
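
execute_model_async() above fans the same execute_model call out to the driver of every pipeline stage, with one asyncio.Lock per stage so two virtual engines cannot enter the same stage concurrently, and only the last stage's result is returned. A reduced sketch of that pattern (fake stage work in place of Ray remote calls):

    import asyncio

    async def _run_task_with_lock(task, lock, *args, **kwargs):
        async with lock:
            return await task(*args, **kwargs)

    async def run_stage(pp_rank: int, request: str) -> str:
        await asyncio.sleep(0.01)    # stand-in for a remote execute_model
        return f"stage{pp_rank}:{request}"

    async def execute(pp_locks, request: str) -> str:
        tasks = [asyncio.create_task(
                     _run_task_with_lock(run_stage, pp_locks[pp_rank],
                                         pp_rank, request))
                 for pp_rank in range(len(pp_locks))]
        results = await asyncio.gather(*tasks)
        return results[-1]           # only the last stage has the output

    async def main():
        pp_locks = [asyncio.Lock() for _ in range(2)]
        print(await asyncio.gather(execute(pp_locks, "a"),
                                   execute(pp_locks, "b")))

    asyncio.run(main())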

+ 2 - 1
aphrodite/modeling/models/arctic.py

@@ -7,7 +7,7 @@ from torch import nn
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -423,6 +423,7 @@ class ArcticForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata)
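
The model-file hunks here and below are the same mechanical change: forward() grows an optional intermediate_tensors parameter so the model runner can call every model with one signature. Models outside _PP_SUPPORTED_MODELS simply accept and ignore it; a hedged sketch of how a PP-aware model uses it (the llama.py changes in this commit follow this shape, though not this exact code):

    from typing import Optional
    import torch

    def forward_stage(layers, x: torch.Tensor,
                      intermediate_tensors: Optional[dict],
                      is_first_stage: bool, is_last_stage: bool):
        if not is_first_stage:
            # Resume from the tensors received from the previous stage
            # instead of re-embedding the input.
            x = intermediate_tensors["hidden_states"]
        for layer in layers:     # only the layers this stage owns
            x = layer(x)
        if not is_last_stage:
            # Hand the hidden states to the next stage.
            return {"hidden_states": x}
        return x                 # last stage feeds the LM head / sampler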

+ 2 - 1
aphrodite/modeling/models/baichuan.py

@@ -27,7 +27,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -335,6 +335,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata)

+ 2 - 1
aphrodite/modeling/models/bloom.py

@@ -25,7 +25,7 @@ from transformers import BloomConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import get_act_fn
@@ -285,6 +285,7 @@ class BloomForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          attn_metadata)

+ 2 - 1
aphrodite/modeling/models/chatglm.py

@@ -10,7 +10,7 @@ from torch.nn import LayerNorm
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -362,6 +362,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          attn_metadata)

+ 2 - 1
aphrodite/modeling/models/commandr.py

@@ -30,7 +30,7 @@ from transformers import CohereConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -352,6 +352,7 @@ class CohereForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata)

+ 2 - 1
aphrodite/modeling/models/dbrx.py

@@ -6,7 +6,7 @@ import torch.nn as nn
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -380,6 +380,7 @@ class DbrxForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          attn_metadata)

+ 2 - 1
aphrodite/modeling/models/deepseek.py

@@ -29,7 +29,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -386,6 +386,7 @@ class DeepseekForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata)

+ 2 - 1
aphrodite/modeling/models/deepseek_v2.py

@@ -30,7 +30,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -475,6 +475,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata)

+ 2 - 1
aphrodite/modeling/models/falcon.py

@@ -28,7 +28,7 @@ from transformers import FalconConfig as HF_FalconConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -384,6 +384,7 @@ class FalconForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(
             input_ids,

+ 3 - 2
aphrodite/modeling/models/gemma.py

@@ -24,7 +24,7 @@ from transformers import GemmaConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import GeluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -37,9 +37,9 @@ from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import \
     VocabParallelEmbedding
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.models.interfaces import SupportsLoRA
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.quantization.base_config import QuantizationConfig
-from aphrodite.modeling.models.interfaces import SupportsLoRA
 
 
 @lru_cache(maxsize=None)
@@ -332,6 +332,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata)

+ 2 - 1
aphrodite/modeling/models/gpt2.py

@@ -25,7 +25,7 @@ from transformers import GPT2Config
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -227,6 +227,7 @@ class GPT2LMHeadModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          attn_metadata)

+ 2 - 1
aphrodite/modeling/models/gpt_bigcode.py

@@ -26,7 +26,7 @@ from transformers import GPTBigCodeConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -246,6 +246,7 @@ class GPTBigCodeForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          attn_metadata)

+ 2 - 1
aphrodite/modeling/models/gpt_j.py

@@ -24,7 +24,7 @@ from transformers import GPTJConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -238,6 +238,7 @@ class GPTJForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          attn_metadata)

+ 2 - 1
aphrodite/modeling/models/gpt_neox.py

@@ -24,7 +24,7 @@ from transformers import GPTNeoXConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -250,6 +250,7 @@ class GPTNeoXForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                       attn_metadata)

+ 2 - 1
aphrodite/modeling/models/internlm2.py

@@ -7,7 +7,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -262,6 +262,7 @@ class InternLM2ForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata)

+ 2 - 1
aphrodite/modeling/models/jais.py

@@ -27,7 +27,7 @@ from torch import nn
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -288,6 +288,7 @@ class JAISLMHeadModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          attn_metadata)

+ 72 - 28
aphrodite/modeling/models/llama.py

@@ -21,7 +21,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only LLaMA model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -29,9 +29,10 @@ from transformers import LlamaConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.common.utils import is_hip, print_warning_once
-from aphrodite.distributed import (get_tensor_model_parallel_rank,
+from aphrodite.distributed import (get_pp_group, get_pp_indices,
+                                   get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -263,12 +264,20 @@ class LlamaModel(nn.Module):
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
         )
-        self.layers = nn.ModuleList([
-            LlamaDecoderLayer(config=config,
-                              cache_config=cache_config,
-                              quant_config=quant_config)
-            for idx in range(config.num_hidden_layers)
-        ])
+        self.start_layer, self.end_layer = get_pp_indices(
+            config.num_hidden_layers,
+            get_pp_group().rank_in_group,
+            get_pp_group().world_size)
+        self.layers = nn.ModuleList(
+            [nn.Identity() for _ in range(self.start_layer)] + [
+                LlamaDecoderLayer(config=config,
+                                  cache_config=cache_config,
+                                  quant_config=quant_config)
+                for _ in range(self.start_layer, self.end_layer)
+            ] + [
+                nn.Identity()
+                for _ in range(self.end_layer, config.num_hidden_layers)
+            ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -280,22 +289,35 @@ class LlamaModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
         else:
-            hidden_states = self.get_input_embeddings(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                kv_caches[i],
+                kv_caches[i - self.start_layer],
                 attn_metadata,
                 residual,
             )
+
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
@@ -372,10 +394,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-    ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
-        return hidden_states
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        model_output = self.model(input_ids, positions, kv_caches,
+                                  attn_metadata, intermediate_tensors)
+        return model_output
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
@@ -391,6 +414,20 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -416,9 +453,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
+                try:
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id)
+                except KeyError:
+                    pass
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
@@ -437,10 +477,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
                         continue
                     else:
                         name = remapped_kv_scale_name
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+                try:
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+                except KeyError:
+                    # Same as above: the parameter lives on another stage.
+                    pass
 
     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
@@ -452,7 +495,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
                 quantization_param_path, tp_rank, tp_size,
                 self.config.num_hidden_layers,
                 self.config.__class__.model_type):
-            layer_self_attn = self.model.layers[layer_idx].self_attn
+            if isinstance(self.model.layers[layer_idx], nn.Identity):
+                # Layer belongs to another pipeline stage; skip it so
+                # layer_self_attn below is never left unbound.
+                continue
+            layer_self_attn = self.model.layers[layer_idx].self_attn
 
             if is_hip():
                 # The scaling factor convention we are assuming is
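
llama.py is the reference port: get_pp_indices maps each pipeline rank to a contiguous [start_layer, end_layer) slice, layers outside the slice become nn.Identity placeholders, KV caches are indexed relative to start_layer, and every rank except the last returns IntermediateTensors instead of final hidden states. A sketch of the partitioning, assuming an even split (the real helper lives in aphrodite/distributed/utils.py and may instead require num_layers to divide evenly):

    def get_pp_indices(num_layers: int, rank: int, world_size: int):
        # Sketch: contiguous, roughly equal slices; any remainder is
        # assumed to land on the last stage.
        per_stage = num_layers // world_size
        start_layer = rank * per_stage
        end_layer = (num_layers if rank == world_size - 1
                     else start_layer + per_stage)
        return start_layer, end_layer

    # 32 layers on 2 stages -> rank 0 owns [0, 16), rank 1 owns [16, 32).
    # Each stage allocates KV cache only for its own slice, hence the
    # kv_caches[i - self.start_layer] indexing in the forward loop above.

make_empty_intermediate_tensors supplies correctly shaped zero buffers so that non-first ranks can run memory profiling and CUDA-graph capture without a live upstream stage.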

+ 4 - 2
aphrodite/modeling/models/llava.py

@@ -6,10 +6,10 @@ from transformers import CLIPVisionConfig, LlavaConfig
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, VisionLanguageConfig
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.inputs import INPUT_REGISTRY, InputContext
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.quantization.base_config import (QuantizationConfig)
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
@@ -17,7 +17,7 @@ from aphrodite.modeling.models.clip import CLIPVisionModel
 from aphrodite.modeling.models.llama import LlamaModel
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.quantization.base_config import QuantizationConfig
 
 from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
 from .interfaces import SupportsVision
@@ -202,6 +202,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         **kwargs: object,
     ) -> SamplerOutput:
         """Run forward pass for LLaVA-1.5.
@@ -247,6 +248,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
                                             positions,
                                             kv_caches,
                                             attn_metadata,
+                                            None,
                                             inputs_embeds=inputs_embeds)
 
         return hidden_states
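
One subtlety in the vision wrappers: LlamaModel.forward now takes intermediate_tensors positionally ahead of inputs_embeds (see the llama.py hunk above), so LLaVA passes an explicit None here, and LLaVA-NeXT below does the same; pipeline parallelism is not wired through the vision models in this change. Schematically:

    def vision_forward(language_model, input_ids, positions, kv_caches,
                       attn_metadata, inputs_embeds):
        # The explicit None fills the new positional intermediate_tensors
        # slot that now precedes the inputs_embeds keyword.
        return language_model(input_ids, positions, kv_caches, attn_metadata,
                              None, inputs_embeds=inputs_embeds)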

+ 3 - 1
aphrodite/modeling/models/llava_next.py

@@ -11,7 +11,7 @@ from typing_extensions import NotRequired
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, VisionLanguageConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.inputs import INPUT_REGISTRY, InputContext
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.sampler import Sampler
@@ -373,6 +373,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         **kwargs: object,
     ) -> SamplerOutput:
         """Run forward pass for LlaVA-NeXT.
@@ -427,6 +428,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
                                             positions,
                                             kv_caches,
                                             attn_metadata,
+                                            None,
                                             inputs_embeds=inputs_embeds)
 
         return hidden_states

+ 2 - 1
aphrodite/modeling/models/minicpm.py

@@ -30,7 +30,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -458,6 +458,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata)

+ 2 - 1
aphrodite/modeling/models/mixtral.py

@@ -30,7 +30,7 @@ from transformers import MixtralConfig
 from aphrodite import _custom_ops as ops
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.common.utils import print_warning_once
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
@@ -529,6 +529,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata)

+ 2 - 1
aphrodite/modeling/models/mixtral_quant.py

@@ -31,7 +31,7 @@ from transformers import MixtralConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -353,6 +353,7 @@ class MixtralForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata)

+ 2 - 1
aphrodite/modeling/models/mpt.py

@@ -8,7 +8,7 @@ import torch.nn as nn
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import get_act_fn
@@ -272,6 +272,7 @@ class MPTForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          attn_metadata)

+ 2 - 1
aphrodite/modeling/models/olmo.py

@@ -29,7 +29,7 @@ from transformers import OlmoConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -300,6 +300,7 @@ class OlmoForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(
             input_ids=input_ids,

+ 2 - 1
aphrodite/modeling/models/opt.py

@@ -25,7 +25,7 @@ from transformers import OPTConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -303,6 +303,7 @@ class OPTForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata)

+ 2 - 1
aphrodite/modeling/models/orion.py

@@ -12,7 +12,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -268,6 +268,7 @@ class OrionForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata)

+ 2 - 1
aphrodite/modeling/models/phi.py

@@ -43,7 +43,7 @@ from transformers import PhiConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -274,6 +274,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata)

+ 2 - 1
aphrodite/modeling/models/phi3_small.py

@@ -7,7 +7,7 @@ from transformers.configuration_utils import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -411,6 +411,7 @@ class Phi3SmallForCausalLM(nn.Module):
         positions: Optional[torch.LongTensor],
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         output_hidden_states = self.model(
             input_ids=input_ids,

+ 8 - 3
aphrodite/modeling/models/phi3v.py

@@ -25,7 +25,7 @@ from transformers import CLIPVisionConfig, PretrainedConfig
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, VisionLanguageConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.inputs import INPUT_REGISTRY, InputContext
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.sampler import Sampler
@@ -378,9 +378,13 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
 
         return None
 
-    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
+    def forward(self,
+                input_ids: torch.Tensor,
+                positions: torch.Tensor,
                 kv_caches: List[torch.Tensor],
-                attn_metadata: AttentionMetadata, **kwargs: object):
+                attn_metadata: AttentionMetadata,
+                intermediate_tensors: Optional[IntermediateTensors] = None,
+                **kwargs: object):
         image_input = self._parse_and_validate_image_input(**kwargs)
 
         if image_input is not None:
@@ -395,6 +399,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
                                    positions,
                                    kv_caches,
                                    attn_metadata,
+                                   intermediate_tensors,
                                    inputs_embeds=inputs_embeds)
 
         return hidden_states

+ 2 - 1
aphrodite/modeling/models/qwen.py

@@ -12,7 +12,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -243,6 +243,7 @@ class QWenLMHeadModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          attn_metadata)

+ 2 - 1
aphrodite/modeling/models/qwen2.py

@@ -30,7 +30,7 @@ from transformers import Qwen2Config
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.common.utils import print_warning_once
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -326,6 +326,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata)

+ 2 - 1
aphrodite/modeling/models/qwen2_moe.py

@@ -31,7 +31,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -396,6 +396,7 @@ class Qwen2MoeForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata)

+ 2 - 1
aphrodite/modeling/models/stablelm.py

@@ -27,7 +27,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -249,6 +249,7 @@ class StablelmForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata)

+ 2 - 1
aphrodite/modeling/models/starcoder2.py

@@ -26,7 +26,7 @@ from transformers import Starcoder2Config
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -261,6 +261,7 @@ class Starcoder2ForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata)

+ 2 - 1
aphrodite/modeling/models/xverse.py

@@ -28,7 +28,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
-from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -316,6 +316,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata)

+ 3 - 0
aphrodite/processing/block_manager_v1.py

@@ -472,6 +472,9 @@ class BlockSpaceManagerV1(BlockSpaceManager):
     def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
         # NOTE: fork does not allocate a new physical block.
         # Thus, it is always safe from OOM.
+        if parent_seq.seq_id not in self.block_tables:
+            # Parent sequence has either been freed or never existed.
+            return
         src_block_table = self.block_tables[parent_seq.seq_id]
         self.block_tables[child_seq.seq_id] = src_block_table.copy()
         # When using a sliding window, blocks will be eventually reused.

+ 3 - 0
aphrodite/processing/block_manager_v2.py

@@ -315,6 +315,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
             computed_seq_block_ids)  # type: ignore
 
     def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
+        if parent_seq.seq_id not in self.block_tables:
+            # Parent sequence has either been freed or never existed.
+            return
         src_block_table = self.block_tables[parent_seq.seq_id]
         self.block_tables[child_seq.seq_id] = src_block_table.fork()
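
Both block-manager versions now treat fork() as a no-op when the parent's block table is gone, rather than raising KeyError. Behavior sketch (manager, parent_seq, and child_seq are stand-ins, not names from this diff):

    manager.free(parent_seq)             # parent's block table is dropped
    manager.fork(parent_seq, child_seq)  # silent no-op instead of KeyError
    assert child_seq.seq_id not in manager.block_tables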
 

+ 11 - 2
aphrodite/processing/scheduler.py

@@ -255,6 +255,7 @@ class Scheduler:
         scheduler_config: SchedulerConfig,
         cache_config: CacheConfig,
         lora_config: Optional[LoRAConfig],
+        pipeline_parallel_size: int = 1,
     ) -> None:
         self.scheduler_config = scheduler_config
         self.cache_config = cache_config
@@ -272,11 +273,19 @@ class Scheduler:
         BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
             version)
 
+        num_gpu_blocks = cache_config.num_gpu_blocks
+        if num_gpu_blocks:
+            num_gpu_blocks //= pipeline_parallel_size
+
+        num_cpu_blocks = cache_config.num_cpu_blocks
+        if num_cpu_blocks:
+            num_cpu_blocks //= pipeline_parallel_size
+
         # Create the block space manager.
         self.block_manager = BlockSpaceManagerImpl(
             block_size=self.cache_config.block_size,
-            num_gpu_blocks=self.cache_config.num_gpu_blocks,
-            num_cpu_blocks=self.cache_config.num_cpu_blocks,
+            num_gpu_blocks=num_gpu_blocks,
+            num_cpu_blocks=num_cpu_blocks,
             sliding_window=self.cache_config.sliding_window,
             enable_caching=self.cache_config.enable_prefix_caching)
 

+ 10 - 5
aphrodite/spec_decode/draft_model_runner.py

@@ -5,7 +5,8 @@ import torch
 from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      SchedulerConfig, VisionLanguageConfig)
-from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
+                                       SequenceGroupMetadata)
 from aphrodite.task_handler.model_runner import (
     ModelInputForGPUWithSamplingMetadata, ModelRunner)
 
@@ -70,9 +71,9 @@ class TP1DraftModelRunner(ModelRunner):
             List[SequenceGroupMetadata]] = None
 
     def prepare_model_input(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> ModelInputForGPUWithSamplingMetadata:
+            self,
+            seq_group_metadata_list: List[SequenceGroupMetadata],
+            virtual_engine: int = 0) -> ModelInputForGPUWithSamplingMetadata:
         """A temporary solution that caches the seq_group_metadata_list
         for multi-step execution.
         TODO: In-place update model_input and remove this function.
@@ -111,6 +112,7 @@ class TP1DraftModelRunner(ModelRunner):
         self,
         model_input: ModelInputForGPUWithSamplingMetadata,
         kv_caches: List[torch.Tensor],
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
     ) -> Optional[List[SamplerOutput]]:
         # Since we do not broadcast data inside execute_model anymore,
@@ -126,6 +128,7 @@ class TP1DraftModelRunner(ModelRunner):
             self.set_active_loras(model_input.lora_requests,
                                   model_input.lora_mapping)
 
+        virtual_engine = model_input.virtual_engine
         outputs: List[SamplerOutput] = []
         for step in range(num_steps):
             # Currently cuda graph is only supported by the decode phase.
@@ -135,7 +138,8 @@ class TP1DraftModelRunner(ModelRunner):
             if prefill_meta is None and decode_meta.use_cuda_graph:
                 assert model_input.input_tokens is not None
                 graph_batch_size = model_input.input_tokens.shape[0]
-                model_executable = self.graph_runners[graph_batch_size]
+                model_executable = (
+                    self.graph_runners[virtual_engine][graph_batch_size])
             else:
                 model_executable = self.model
 
@@ -145,6 +149,7 @@ class TP1DraftModelRunner(ModelRunner):
                 positions=model_input.input_positions,
                 kv_caches=kv_caches,
                 attn_metadata=model_input.attn_metadata,
+                intermediate_tensors=intermediate_tensors,
                 **multi_modal_kwargs,
             )
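
CUDA graphs are now captured per virtual engine, so every runner that looks up a captured graph indexes twice; embedding_model_runner.py and model_runner.py below make the same change. Lookup sketch:

    def lookup_graph_runner(graph_runners, virtual_engine, batch_size):
        # graph_runners: List[Dict[int, CUDAGraphRunner]] -- outer index
        # is the virtual engine, inner key the padded graph batch size.
        return graph_runners[virtual_engine][batch_size]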
 

+ 4 - 0
aphrodite/task_handler/cache_engine.py

@@ -36,7 +36,11 @@ class CacheEngine:
 
         self.block_size = cache_config.block_size
         self.num_gpu_blocks = cache_config.num_gpu_blocks
+        if self.num_gpu_blocks:
+            self.num_gpu_blocks //= parallel_config.pipeline_parallel_size
         self.num_cpu_blocks = cache_config.num_cpu_blocks
+        if self.num_cpu_blocks:
+            self.num_cpu_blocks //= parallel_config.pipeline_parallel_size
 
         if cache_config.cache_dtype == "auto":
             self.dtype = model_config.dtype

+ 4 - 1
aphrodite/task_handler/cpu_model_runner.py

@@ -9,7 +9,8 @@ from aphrodite.attention import AttentionMetadata, get_attn_backend
 from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      SchedulerConfig, VisionLanguageConfig)
-from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
+                                       SequenceGroupMetadata)
 from aphrodite.common.utils import make_tensor_with_pad
 from aphrodite.modeling import SamplingMetadata
 from aphrodite.modeling.model_loader import get_model
@@ -312,6 +313,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
     def prepare_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
+        virtual_engine: int = 0,
     ) -> CPUModelInput:
         multi_modal_kwargs = None
         # NOTE: We assume that all sequences in the group are all prompts or
@@ -348,6 +350,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
         self,
         model_input: CPUModelInput,
         kv_caches: List[torch.Tensor],
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
     ) -> Optional[List[SamplerOutput]]:
         if num_steps > 1:

+ 24 - 14
aphrodite/task_handler/cpu_worker.py

@@ -165,8 +165,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
             is_driver_worker=is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
-        self.cache_engine: CPUCacheEngine
-        self.cpu_cache: List[torch.Tensor]
+        self.cache_engine: List[CPUCacheEngine]
+        self.cpu_cache: List[List[torch.Tensor]]
 
     def init_device(self) -> None:
         self.init_distributed_environment()
@@ -241,25 +241,32 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
                 "when initializing the engine.")
 
     def _init_cache_engine(self) -> None:
-        self.cache_engine = CPUCacheEngine(self.cache_config,
-                                           self.model_config,
-                                           self.parallel_config,
-                                           self.device_config)
-        self.cpu_cache = self.cache_engine.cpu_cache
-        self.model_runner.block_size = self.cache_engine.block_size
-
-        assert self.cpu_cache is not None
+        self.cache_engine = [
+            CPUCacheEngine(self.cache_config, self.model_config,
+                           self.parallel_config, self.device_config)
+            for _ in range(self.parallel_config.pipeline_parallel_size)
+        ]
+        self.cpu_cache = [
+            self.cache_engine[ve].cpu_cache
+            for ve in range(self.parallel_config.pipeline_parallel_size)
+        ]
+        self.model_runner.block_size = self.cache_engine[0].block_size
+
+        assert all(
+            self.cpu_cache[ve] is not None
+            for ve in range(self.parallel_config.pipeline_parallel_size))
 
         # Populate the cache to warmup the memory
-        for layer_cache in self.cpu_cache:
-            layer_cache.fill_(0)
+        for ve in range(self.parallel_config.pipeline_parallel_size):
+            for layer_cache in self.cpu_cache[ve]:
+                layer_cache.fill_(0)
 
     @property
     def do_metadata_broadcast(self) -> bool:
         return self.parallel_config.tensor_parallel_size > 1
 
     @property
-    def kv_cache(self) -> Optional[List[torch.Tensor]]:
+    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
         return self.cpu_cache
 
     def execute_worker(
@@ -268,12 +275,14 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
     ) -> None:
         if (worker_input.blocks_to_copy is not None
                 and worker_input.blocks_to_copy.numel() > 0):
-            self.cache_engine.copy(worker_input.blocks_to_copy)
+            self.cache_engine[worker_input.virtual_engine].copy(
+                worker_input.blocks_to_copy)
 
     @torch.inference_mode()
     def prepare_worker_input(
             self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
         assert execute_model_req is not None
+        virtual_engine = execute_model_req.virtual_engine
         num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
         blocks_to_copy = execute_model_req.blocks_to_copy
         blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
@@ -284,6 +293,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         return WorkerInput(
             num_seq_groups=num_seq_groups,
             blocks_to_copy=blocks_to_copy,
+            virtual_engine=virtual_engine,
         )
 
     def init_distributed_environment(self) -> None:
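
The CPU worker now mirrors the GPU path: one CPUCacheEngine per virtual engine, with the KV cache indexed first by virtual engine and then by layer (cpu_cache[virtual_engine][layer]). Routing sketch (names as in the hunks above):

    def copy_blocks(cache_engine, worker_input):
        # cache_engine: List[CPUCacheEngine], one per virtual engine; the
        # engine decides which virtual engine runs the current step.
        ve = worker_input.virtual_engine
        cache_engine[ve].copy(worker_input.blocks_to_copy)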

+ 8 - 3
aphrodite/task_handler/embedding_model_runner.py

@@ -7,8 +7,8 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      SchedulerConfig, VisionLanguageConfig)
 from aphrodite.common.pooling_params import PoolingParams
-from aphrodite.common.sequence import (PoolerOutput, SequenceData,
-                                       SequenceGroupMetadata)
+from aphrodite.common.sequence import (IntermediateTensors, PoolerOutput,
+                                       SequenceData, SequenceGroupMetadata)
 from aphrodite.modeling.pooling_metadata import PoolingMetadata
 from aphrodite.task_handler.model_runner import (GPUModelRunnerBase,
                                                  ModelInputForGPU)
@@ -56,11 +56,13 @@ class EmbeddingModelRunner(
         self,
         model_input: ModelInputForGPUWithPoolingMetadata,
         kv_caches: List[torch.Tensor],
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
     ) -> Optional[List[PoolerOutput]]:
         if num_steps > 1:
             raise ValueError(
                 "EmbeddingModelRunner does not support multi-step execution.")
+
         if self.lora_config:
             assert model_input.lora_requests is not None
             assert model_input.lora_mapping is not None
@@ -71,10 +73,12 @@ class EmbeddingModelRunner(
         assert model_input.attn_metadata is not None
         prefill_meta = model_input.attn_metadata.prefill_metadata
         decode_meta = model_input.attn_metadata.decode_metadata
+        virtual_engine = model_input.virtual_engine
         if prefill_meta is None and decode_meta.use_cuda_graph:
             assert model_input.input_tokens is not None
             graph_batch_size = model_input.input_tokens.shape[0]
-            model_executable = self.graph_runners[graph_batch_size]
+            model_executable = self.graph_runners[virtual_engine][
+                graph_batch_size]
         else:
             model_executable = self.model
 
@@ -113,6 +117,7 @@ class EmbeddingModelRunner(
     def prepare_model_input(
         self,
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+        virtual_engine: int = 0,
     ) -> ModelInputForGPUWithPoolingMetadata:
         assert seq_group_metadata_list is not None
         model_input = self._prepare_model_input_tensors(

+ 201 - 125
aphrodite/task_handler/model_runner.py

@@ -8,6 +8,7 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
 
 import numpy as np
 import torch
+import torch.distributed
 import torch.nn as nn
 from loguru import logger
 
@@ -27,11 +28,13 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      SchedulerConfig, VisionLanguageConfig)
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
+                                       SequenceGroupMetadata)
 from aphrodite.common.utils import (CudaMemoryProfiler,
                                     get_kv_cache_torch_dtype, is_hip,
                                     is_pin_memory_available,
                                     make_tensor_with_pad)
+from aphrodite.distributed import get_pp_group
 from aphrodite.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
     graph_capture)
@@ -83,6 +86,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
     lora_requests: Optional[Set[LoRARequest]] = None
     attn_metadata: Optional["AttentionMetadata"] = None
     multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
+    virtual_engine: int = 0
 
     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
@@ -91,6 +95,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
             "lora_requests": self.lora_requests,
             "lora_mapping": self.lora_mapping,
             "multi_modal_kwargs": self.multi_modal_kwargs,
+            "virtual_engine": self.virtual_engine,
         }
         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
         return tensor_dict
@@ -124,6 +129,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
             "lora_requests": self.lora_requests,
             "lora_mapping": self.lora_mapping,
             "multi_modal_kwargs": self.multi_modal_kwargs,
+            "virtual_engine": self.virtual_engine,
         }
         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
         _add_sampling_metadata_broadcastable_dict(tensor_dict,
@@ -181,7 +187,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
         self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
-        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
+
+        self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
+            {} for _ in range(self.parallel_config.pipeline_parallel_size)
+        ]
         self.graph_memory_pool: Optional[Tuple[
             int, int]] = None  # Set during graph capture.
         # When using CUDA graph, the input block tables must be padded to
@@ -806,9 +815,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             max_num_seqs = min(
                 max_num_seqs,
                 int(max_num_batched_tokens / vlm_config.image_feature_size))
+        batch_size = 0
         for group_id in range(max_num_seqs):
             seq_len = (max_num_batched_tokens // max_num_seqs +
                        (group_id < max_num_batched_tokens % max_num_seqs))
+            batch_size += seq_len
 
             seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
                 .dummy_data_for_profiling(model_config, seq_len)
@@ -830,7 +841,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         kv_caches = [None] * num_layers
         model_input = self.prepare_model_input(seqs)
-        self.execute_model(model_input, kv_caches)
+        intermediate_tensors = None
+        if not get_pp_group().is_first_rank:
+            intermediate_tensors = self.model.make_empty_intermediate_tensors(
+                batch_size=batch_size,
+                dtype=self.model_config.dtype,
+                device=self.device)
+        self.execute_model(model_input, kv_caches, intermediate_tensors)
         torch.cuda.synchronize()
         return
 
@@ -866,7 +883,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         return self.lora_manager.list_loras()
 
     @torch.inference_mode()
-    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
+    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
         """Cuda graph capture a model.
 
         Note that CUDA graph's performance gain is negligible if number
@@ -899,10 +916,18 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         slot_mapping.fill_(_PAD_SLOT_ID)
         seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
         block_tables = torch.from_numpy(self.graph_block_tables).cuda()
+        intermediate_inputs = None
+        if not get_pp_group().is_first_rank:
+            intermediate_inputs = self.model.make_empty_intermediate_tensors(
+                batch_size=max_batch_size,
+                dtype=self.model_config.dtype,
+                device=self.device)
 
         # Prepare buffer for outputs. These will be reused for all batch sizes.
         # It will be filled after the first graph capture.
-        hidden_states: Optional[torch.Tensor] = None
+        hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [
+            None
+        ] * self.parallel_config.pipeline_parallel_size
 
         graph_batch_size = _get_graph_batch_size(
             self.scheduler_config.max_num_seqs)
@@ -931,109 +956,120 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         with graph_capture() as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
-            for batch_size in reversed(batch_size_capture_list):
-                if self.attn_backend.get_name() == "flashinfer":
-                    indptr_buffer = indptr_buffer[:batch_size + 1]
-                    last_page_len_buffer = last_page_len_buffer[:batch_size]
-
-                    num_qo_heads = self.model_config.get_num_attention_heads(
-                        self.parallel_config)
-                    num_kv_heads = self.model_config.get_num_kv_heads(
-                        self.parallel_config)
-                    if num_qo_heads // num_kv_heads >= 4:
-                        use_tensor_cores = True
+            for virtual_engine in range(
+                    self.parallel_config.pipeline_parallel_size):
+                for batch_size in reversed(batch_size_capture_list):
+                    if self.attn_backend.get_name() == "flashinfer":
+                        indptr_buffer = indptr_buffer[:batch_size + 1]
+                        last_page_len_buffer = last_page_len_buffer[:
+                                                                    batch_size]
+
+                        num_qo_heads = (
+                            self.model_config.get_num_attention_heads(
+                                self.parallel_config))
+                        num_kv_heads = self.model_config.get_num_kv_heads(
+                            self.parallel_config)
+                        if num_qo_heads // num_kv_heads >= 4:
+                            use_tensor_cores = True
+                        else:
+                            use_tensor_cores = False
+                        decode_wrapper = \
+                            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
+                            decode_workspace_buffer, indptr_buffer,
+                            indices_buffer, last_page_len_buffer, "NHD",
+                            use_tensor_cores)
+                        kv_cache_dtype = get_kv_cache_torch_dtype(
+                            self.kv_cache_dtype, self.model_config.dtype)
+
+                        paged_kv_indptr_tensor_host = torch.arange(
+                            0, batch_size + 1, dtype=torch.int32)
+                        paged_kv_indices_tensor_host = torch.arange(
+                            0, batch_size, dtype=torch.int32)
+                        paged_kv_last_page_len_tensor_host = torch.full(
+                            (batch_size, ), self.block_size, dtype=torch.int32)
+                        query_start_loc_host = torch.arange(0,
+                                                            batch_size + 1,
+                                                            dtype=torch.int32)
+
+                        attn_metadata = self.attn_backend.make_metadata(
+                            num_prefills=0,
+                            slot_mapping=slot_mapping[:batch_size],
+                            num_prefill_tokens=0,
+                            num_decode_tokens=batch_size,
+                            max_prefill_seq_len=0,
+                            block_tables=block_tables,
+                            paged_kv_indptr=paged_kv_indptr_tensor_host,
+                            paged_kv_indices=paged_kv_indices_tensor_host,
+                            paged_kv_last_page_len=
+                            paged_kv_last_page_len_tensor_host,
+                            num_qo_heads=num_qo_heads,
+                            num_kv_heads=num_kv_heads,
+                            head_dim=self.model_config.get_head_size(),
+                            page_size=self.block_size,
+                            seq_start_loc=None,
+                            query_start_loc=query_start_loc_host,
+                            device=self.device,
+                            data_type=kv_cache_dtype,
+                            use_cuda_graph=True,
+                            decode_wrapper=decode_wrapper,
+                            prefill_wrapper=None)
+                        attn_metadata.begin_forward()
                     else:
-                        use_tensor_cores = False
-                    decode_wrapper = \
-                        CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
-                        decode_workspace_buffer, indptr_buffer, indices_buffer,
-                        last_page_len_buffer, "NHD", use_tensor_cores)
-                    kv_cache_dtype = get_kv_cache_torch_dtype(
-                        self.kv_cache_dtype, self.model_config.dtype)
-
-                    paged_kv_indptr_tensor_host = torch.arange(
-                        0, batch_size + 1, dtype=torch.int32)
-                    paged_kv_indices_tensor_host = torch.arange(
-                        0, batch_size, dtype=torch.int32)
-                    paged_kv_last_page_len_tensor_host = torch.full(
-                        (batch_size, ), self.block_size, dtype=torch.int32)
-                    query_start_loc_host = torch.arange(0,
-                                                        batch_size + 1,
-                                                        dtype=torch.int32)
-
-                    attn_metadata = self.attn_backend.make_metadata(
-                        num_prefills=0,
-                        slot_mapping=slot_mapping[:batch_size],
-                        num_prefill_tokens=0,
-                        num_decode_tokens=batch_size,
-                        max_prefill_seq_len=0,
-                        block_tables=block_tables,
-                        paged_kv_indptr=paged_kv_indptr_tensor_host,
-                        paged_kv_indices=paged_kv_indices_tensor_host,
-                        paged_kv_last_page_len=
-                        paged_kv_last_page_len_tensor_host,
-                        num_qo_heads=num_qo_heads,
-                        num_kv_heads=num_kv_heads,
-                        head_dim=self.model_config.get_head_size(),
-                        page_size=self.block_size,
-                        seq_start_loc=None,
-                        query_start_loc=query_start_loc_host,
-                        device=self.device,
-                        data_type=kv_cache_dtype,
-                        use_cuda_graph=True,
-                        decode_wrapper=decode_wrapper,
-                        prefill_wrapper=None)
-                    attn_metadata.begin_forward()
-                else:
-                    attn_metadata = self.attn_backend.make_metadata(
-                        num_prefills=0,
-                        num_prefill_tokens=0,
-                        num_decode_tokens=batch_size,
-                        slot_mapping=slot_mapping[:batch_size],
-                        seq_lens=None,
-                        seq_lens_tensor=seq_lens[:batch_size],
-                        max_query_len=None,
-                        max_prefill_seq_len=0,
-                        max_decode_seq_len=self.max_seq_len_to_capture,
-                        query_start_loc=None,
-                        seq_start_loc=None,
-                        context_lens_tensor=None,
-                        block_tables=block_tables[:batch_size],
-                        use_cuda_graph=True,
+                        attn_metadata = self.attn_backend.make_metadata(
+                            num_prefills=0,
+                            num_prefill_tokens=0,
+                            num_decode_tokens=batch_size,
+                            slot_mapping=slot_mapping[:batch_size],
+                            seq_lens=None,
+                            seq_lens_tensor=seq_lens[:batch_size],
+                            max_query_len=None,
+                            max_prefill_seq_len=0,
+                            max_decode_seq_len=self.max_seq_len_to_capture,
+                            query_start_loc=None,
+                            seq_start_loc=None,
+                            context_lens_tensor=None,
+                            block_tables=block_tables[:batch_size],
+                            use_cuda_graph=True,
+                        )
+
+                    if self.lora_config:
+                        lora_mapping = LoRAMapping(
+                            [0] * batch_size,
+                            [0] * batch_size,
+                        )
+                        self.set_active_loras(set(), lora_mapping)
+
+                    graph_runner = CUDAGraphRunner(
+                        self.model, self.attn_backend.get_name())
+
+                    if self.attn_backend.get_name() == "flashinfer":
+                        graph_runner.flashinfer_indptr_buffer = indptr_buffer
+                        graph_runner.flashinfer_indices_buffer = indices_buffer
+                        graph_runner.flashinfer_last_page_len_buffer = \
+                            last_page_len_buffer
+                        graph_runner.flashinfer_decode_workspace_buffer = \
+                                decode_workspace_buffer
+                        graph_runner.flashinfer_decode_wrapper = \
+                            decode_wrapper
+
+                    graph_runner.capture(
+                        input_tokens[:batch_size],
+                        input_positions[:batch_size],
+                        hidden_or_intermediate_states[
+                            virtual_engine]  # type: ignore
+                        [:batch_size]
+                        if hidden_or_intermediate_states[virtual_engine]
+                        is not None else None,
+                        intermediate_inputs[:batch_size]
+                        if intermediate_inputs is not None else None,
+                        kv_caches[virtual_engine],
+                        attn_metadata,
+                        memory_pool=self.graph_memory_pool,
+                        stream=graph_capture_context.stream,
                     )
-
-                if self.lora_config:
-                    lora_mapping = LoRAMapping(
-                        [0] * batch_size,
-                        [0] * batch_size,
-                    )
-                    self.set_active_loras(set(), lora_mapping)
-
-                graph_runner = CUDAGraphRunner(self.model,
-                                               self.attn_backend.get_name())
-
-                if self.attn_backend.get_name() == "flashinfer":
-                    graph_runner.flashinfer_indptr_buffer = indptr_buffer
-                    graph_runner.flashinfer_indices_buffer = indices_buffer
-                    graph_runner.flashinfer_last_page_len_buffer = \
-                        last_page_len_buffer
-                    graph_runner.flashinfer_decode_workspace_buffer = \
-                            decode_workspace_buffer
-                    graph_runner.flashinfer_decode_wrapper = \
-                        decode_wrapper
-
-                graph_runner.capture(
-                    input_tokens[:batch_size],
-                    input_positions[:batch_size],
-                    hidden_states[:batch_size]
-                    if hidden_states is not None else None,
-                    kv_caches,
-                    attn_metadata,
-                    memory_pool=self.graph_memory_pool,
-                    stream=graph_capture_context.stream,
-                )
-                self.graph_memory_pool = graph_runner.graph.pool()
-                self.graph_runners[batch_size] = graph_runner
+                    self.graph_memory_pool = graph_runner.graph.pool()
+                    self.graph_runners[virtual_engine][batch_size] = (
+                        graph_runner)
 
         end_time = time.perf_counter()
         elapsed_time = end_time - start_time
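
The capture loop above now keys captured graphs by virtual engine first and batch size second, so each pipeline stage replays its own set of graphs. A minimal sketch of that lookup structure, with a toy stand-in for CUDAGraphRunner (the names below are illustrative, not the project's API):

    from typing import Dict, List

    class ToyGraphRunner:
        """Stand-in for a captured CUDA graph (hypothetical)."""

        def __init__(self, virtual_engine: int, batch_size: int) -> None:
            self.virtual_engine = virtual_engine
            self.batch_size = batch_size

    pipeline_parallel_size = 2
    batch_size_capture_list = [1, 2, 4, 8]

    # One dict of captured graphs per virtual engine (PP stream).
    graph_runners: List[Dict[int, ToyGraphRunner]] = [
        {} for _ in range(pipeline_parallel_size)
    ]

    for virtual_engine in range(pipeline_parallel_size):
        # Largest batch first, mirroring the capture order above.
        for batch_size in reversed(batch_size_capture_list):
            graph_runners[virtual_engine][batch_size] = ToyGraphRunner(
                virtual_engine, batch_size)

    # Replay-time lookup mirrors execute_model below.
    runner = graph_runners[1][4]
    assert (runner.virtual_engine, runner.batch_size) == (1, 4)
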
@@ -1066,6 +1102,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
     def prepare_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
+        virtual_engine: int = 0,
     ) -> ModelInputForGPUWithSamplingMetadata:
         """Prepare the model input based on a given sequence group, including
         metadata for the sampling step.
@@ -1091,15 +1128,17 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
                      if seq_group_metadata_list else None)
         return dataclasses.replace(model_input,
                                    sampling_metadata=sampling_metadata,
-                                   is_prompt=is_prompt)
+                                   is_prompt=is_prompt,
+                                   virtual_engine=virtual_engine)
 
     @torch.inference_mode()
     def execute_model(
         self,
         model_input: ModelInputForGPUWithSamplingMetadata,
         kv_caches: List[torch.Tensor],
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
-    ) -> Optional[List[SamplerOutput]]:
+    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
         if num_steps > 1:
             raise ValueError("num_steps > 1 is not supported in ModelRunner")
 
@@ -1143,27 +1182,34 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         assert model_input.attn_metadata is not None
         prefill_meta = model_input.attn_metadata.prefill_metadata
         decode_meta = model_input.attn_metadata.decode_metadata
+        # TODO: We can remove this once all
+        # virtual engines share the same kv cache.
+        virtual_engine = model_input.virtual_engine
         if prefill_meta is None and decode_meta.use_cuda_graph:
             assert model_input.input_tokens is not None
             graph_batch_size = model_input.input_tokens.shape[0]
-            model_executable = self.graph_runners[graph_batch_size]
+            model_executable = self.graph_runners[virtual_engine][
+                graph_batch_size]
         else:
             model_executable = self.model
 
         multi_modal_kwargs = model_input.multi_modal_kwargs or {}
-        hidden_states = model_executable(
+        hidden_or_intermediate_states = model_executable(
             input_ids=model_input.input_tokens,
             positions=model_input.input_positions,
             kv_caches=kv_caches,
             attn_metadata=model_input.attn_metadata,
+            intermediate_tensors=intermediate_tensors,
             **multi_modal_kwargs,
         )
 
-        # Compute the logits.
-        logits = self.model.compute_logits(hidden_states,
+        # Compute the logits in the last pipeline stage.
+        if not get_pp_group().is_last_rank:
+            return hidden_or_intermediate_states
+
+        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                            model_input.sampling_metadata)
 
-        # Only perform sampling in the driver worker.
         if not self.is_driver_worker:
             return []
 
@@ -1178,9 +1224,12 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
             assert model_input.sampling_metadata is not None
             indices = model_input.sampling_metadata.selected_token_indices
             if model_input.is_prompt:
-                hidden_states = hidden_states.index_select(0, indices)
+                hidden_states = hidden_or_intermediate_states.index_select(
+                    0, indices)
             elif decode_meta.use_cuda_graph:
-                hidden_states = hidden_states[:len(indices)]
+                hidden_states = hidden_or_intermediate_states[:len(indices)]
+            else:
+                hidden_states = hidden_or_intermediate_states
 
             output.hidden_states = hidden_states
 
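
The reworked execute_model encodes a per-stage contract: stages before the last return an IntermediateTensors bundle for the next stage, the last stage computes logits, and only the last-stage driver samples. A single-process sketch of that dispatch, with plain booleans standing in for the real rank queries (all classes and values here are invented for illustration):

    from typing import List, Optional, Union

    import torch

    class ToyIntermediateTensors:
        """Minimal stand-in for IntermediateTensors: a named bag of tensors."""

        def __init__(self, tensors: dict) -> None:
            self.tensors = tensors

    def execute_stage(
        hidden: torch.Tensor,
        is_last_rank: bool,
        is_driver_worker: bool,
    ) -> Union[ToyIntermediateTensors, List[Optional[str]]]:
        if not is_last_rank:
            # Intermediate stages ship their activations to the next stage.
            return ToyIntermediateTensors({"hidden_states": hidden})
        if not is_driver_worker:
            # Last-stage non-driver ranks compute logits but never sample.
            return []
        # Last-stage driver: compute logits, then sample.
        return ["sampler_output"]

    h = torch.zeros(4, 8)
    assert isinstance(execute_stage(h, False, False), ToyIntermediateTensors)
    assert execute_stage(h, True, False) == []
    assert execute_stage(h, True, True) == ["sampler_output"]
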
@@ -1214,13 +1263,15 @@ class CUDAGraphRunner:
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        hidden_states: Optional[torch.Tensor],
+        hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
+                                                      torch.Tensor]],
+        intermediate_inputs: Optional[IntermediateTensors],
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         memory_pool: Optional[Tuple[int, int]],
         stream: torch.cuda.Stream,
         **kwargs,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, IntermediateTensors]:
         assert self._graph is None
         # Run the model a few times without capturing the graph.
         # This is to make sure that the captured graph does not include the
@@ -1232,6 +1283,7 @@ class CUDAGraphRunner:
                 positions,
                 kv_caches,
                 attn_metadata,
+                intermediate_inputs,
                 **kwargs,
             )
         torch.cuda.synchronize()
@@ -1239,18 +1291,27 @@ class CUDAGraphRunner:
         # Capture the graph.
         self._graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
-            output_hidden_states = self.model(
+            output_hidden_or_intermediate_states = self.model(
                 input_ids,
                 positions,
                 kv_caches,
                 attn_metadata,
+                intermediate_inputs,
                 **kwargs,
             )
-            if hidden_states is not None:
-                hidden_states.copy_(output_hidden_states)
+            if hidden_or_intermediate_states is not None:
+                if get_pp_group().is_last_rank:
+                    hidden_or_intermediate_states.copy_(
+                        output_hidden_or_intermediate_states)
+                else:
+                    for key in hidden_or_intermediate_states.tensors:
+                        hidden_or_intermediate_states[key].copy_(
+                            output_hidden_or_intermediate_states[key])
             else:
-                hidden_states = output_hidden_states
-            del output_hidden_states
+                hidden_or_intermediate_states = (
+                    output_hidden_or_intermediate_states)
+
+            del output_hidden_or_intermediate_states
-            # make sure `output_hidden_states` is deleted
+            # make sure `output_hidden_or_intermediate_states` is deleted
             # in the graph's memory pool
             gc.collect()
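
During capture, outputs must be copied into storage that outlives the capture so that every replay writes to the same addresses; on non-last ranks that output is a bag of tensors rather than one hidden-states tensor, hence the per-key copy above. A toy illustration of the copy-into-static-buffers pattern (CPU tensors, hypothetical names):

    import torch

    # Pre-allocated output buffers that survive across graph replays.
    static_outputs = {
        "hidden_states": torch.zeros(4, 8),
        "residual": torch.zeros(4, 8),
    }

    def fake_forward() -> dict:
        # Stands in for the model call captured inside the CUDA graph.
        return {
            "hidden_states": torch.ones(4, 8),
            "residual": torch.full((4, 8), 2.0),
        }

    out = fake_forward()
    for key, buf in static_outputs.items():
        # copy_ writes in place, so the buffer's address never changes;
        # that is what makes the result readable after graph.replay().
        buf.copy_(out[key])

    assert torch.equal(static_outputs["residual"], torch.full((4, 8), 2.0))
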
@@ -1274,8 +1335,15 @@ class CUDAGraphRunner:
                 attn_metadata.decode_metadata.seq_lens_tensor,
                 "block_tables": attn_metadata.decode_metadata.block_tables,
             }
-        self.output_buffers = {"hidden_states": hidden_states}
-        return hidden_states
+        if intermediate_inputs is not None:
+            self.input_buffers.update(intermediate_inputs.tensors)
+        if get_pp_group().is_last_rank:
+            self.output_buffers = {
+                "hidden_states": hidden_or_intermediate_states
+            }
+        else:
+            self.output_buffers = hidden_or_intermediate_states
+        return hidden_or_intermediate_states
 
     def forward(
         self,
@@ -1283,6 +1351,7 @@ class CUDAGraphRunner:
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors],
         **kwargs,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, IntermediateTensors]:
         # KV caches are fixed tensors, so we don't need to copy them.
@@ -1299,11 +1368,18 @@ class CUDAGraphRunner:
                 non_blocking=True)
             self.input_buffers["block_tables"].copy_(
                 attn_metadata.decode_metadata.block_tables, non_blocking=True)
+        if intermediate_tensors is not None:
+            for key in intermediate_tensors.tensors:
+                self.input_buffers[key].copy_(intermediate_tensors[key],
+                                              non_blocking=True)
         # Run the graph.
         self.graph.replay()
 
         # Return the output tensor.
-        return self.output_buffers["hidden_states"]
+        if get_pp_group().is_last_rank:
+            return self.output_buffers["hidden_states"]
+
+        return self.output_buffers
 
     def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
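
At replay time, forward never re-runs the Python model; it only refreshes the static input buffers, which now include any intermediate tensors received from the previous stage, and then replays the graph. A toy sketch of that refresh step (the real version calls graph.replay() where noted):

    import torch

    # Static input buffers registered at capture time.
    input_buffers = {
        "input_ids": torch.zeros(4, dtype=torch.long),
        "hidden_states": torch.zeros(4, 8),  # received from the previous stage
    }

    def replay(new_ids: torch.Tensor, intermediate: dict) -> None:
        # Overwrite buffer contents in place; the graph reads these addresses.
        input_buffers["input_ids"].copy_(new_ids)
        for key, value in intermediate.items():
            input_buffers[key].copy_(value)
        # graph.replay() would run here in the real runner.

    replay(torch.arange(4), {"hidden_states": torch.ones(4, 8)})
    assert int(input_buffers["input_ids"][3]) == 3
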

+ 4 - 1
aphrodite/task_handler/model_runner_base.py

@@ -5,7 +5,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
 
 import torch
 
-from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
+                                       SequenceGroupMetadata)
 
 if TYPE_CHECKING:
     from aphrodite.attention import AttentionMetadata
@@ -137,6 +138,7 @@ class ModelRunnerBase(ABC, Generic[T]):
     def prepare_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
+        virtual_engine: int = 0,
     ) -> T:
         """
         Prepare the inputs to ModelRunnerBase.execute_model from an execution
@@ -150,6 +152,7 @@ class ModelRunnerBase(ABC, Generic[T]):
         self,
         model_input: T,
         kv_caches: Optional[List[torch.Tensor]],
+        intermediate_tensors: Optional[IntermediateTensors],
         num_steps: int = 1,
     ) -> Optional[List[SamplerOutput]]:
         """

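
Every runner implementing this base interface now has to accept virtual_engine and intermediate_tensors, even when the backend ignores them (as the Neuron and XPU runners below do). A skeletal illustration with simplified types; the class and its behavior are invented for this example:

    from typing import Any, List, Optional

    import torch

    class NoopModelRunner:
        """Toy runner showing the widened interface (illustrative only)."""

        def prepare_model_input(
            self,
            seq_group_metadata_list: List[Any],
            virtual_engine: int = 0,
        ) -> dict:
            # Single-engine backends can simply record and ignore the index.
            return {"virtual_engine": virtual_engine}

        def execute_model(
            self,
            model_input: dict,
            kv_caches: Optional[List[torch.Tensor]],
            intermediate_tensors: Optional[Any] = None,
            num_steps: int = 1,
        ) -> Optional[List[Any]]:
            assert intermediate_tensors is None, "no PP on this backend"
            return []

    runner = NoopModelRunner()
    assert runner.execute_model(runner.prepare_model_input([]), None) == []
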
+ 4 - 1
aphrodite/task_handler/neuron_model_runner.py

@@ -7,7 +7,8 @@ from torch import nn
 
 from aphrodite.common.config import (DeviceConfig, ModelConfig, ParallelConfig,
                                      SchedulerConfig)
-from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
+                                       SequenceGroupMetadata)
 from aphrodite.common.utils import (is_pin_memory_available,
                                     make_tensor_with_pad)
 from aphrodite.modeling import SamplingMetadata
@@ -175,6 +176,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
     def prepare_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
+        virtual_engine: int = 0,
     ) -> ModelInputForNeuron:
         # NOTE: We assume that all sequences in the group are all prompts or
         # all decodes.
@@ -207,6 +209,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
         self,
         model_input: ModelInputForNeuron,
         kv_caches: Optional[List[torch.Tensor]] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
     ) -> Optional[List[SamplerOutput]]:
         if num_steps > 1:

+ 1 - 1
aphrodite/task_handler/neuron_worker.py

@@ -81,7 +81,7 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         return False
 
     @property
-    def kv_cache(self) -> Optional[List[torch.Tensor]]:
+    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
         return None
 
     @torch.inference_mode()

+ 27 - 18
aphrodite/task_handler/worker.py

@@ -62,8 +62,9 @@ class Worker(LocalOrDistributedWorkerBase):
         self.lora_config = lora_config
         self.load_config = load_config
         self.is_driver_worker = is_driver_worker
-        if self.is_driver_worker:
-            assert self.rank == 0, "The driver worker must have rank 0."
+        if parallel_config and is_driver_worker:
+            assert rank % parallel_config.tensor_parallel_size == 0, \
+                   "Driver worker should be rank 0 of tensor parallel group."
 
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
@@ -102,9 +103,9 @@ class Worker(LocalOrDistributedWorkerBase):
         )
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
-        self.cache_engine: CacheEngine
+        self.cache_engine: List[CacheEngine]
         # Initialize gpu_cache as embedding models don't initialize kv_caches
-        self.gpu_cache: Optional[List[torch.tensor]] = None
+        self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
 
     def init_device(self) -> None:
         if self.device_config.device.type == "cuda":
@@ -220,10 +221,15 @@ class Worker(LocalOrDistributedWorkerBase):
 
     def _init_cache_engine(self):
         assert self.cache_config.num_gpu_blocks is not None
-        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
-                                        self.parallel_config,
-                                        self.device_config)
-        self.gpu_cache = self.cache_engine.gpu_cache
+        self.cache_engine = [
+            CacheEngine(self.cache_config, self.model_config,
+                        self.parallel_config, self.device_config)
+            for _ in range(self.parallel_config.pipeline_parallel_size)
+        ]
+        self.gpu_cache = [
+            self.cache_engine[ve].gpu_cache
+            for ve in range(self.parallel_config.pipeline_parallel_size)
+        ]
 
     def _warm_up_model(self) -> None:
         if not self.model_config.enforce_eager:
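
With pipeline parallelism, each virtual engine owns its own CacheEngine, and cache operations are routed by the virtual_engine index carried in the worker input. A toy sketch of the construct-then-dispatch pattern (ToyCacheEngine is a hypothetical stand-in):

    from typing import List

    class ToyCacheEngine:
        """Hypothetical stand-in for CacheEngine."""

        def __init__(self) -> None:
            self.swapped_in: List[int] = []

        def swap_in(self, block: int) -> None:
            self.swapped_in.append(block)

    pipeline_parallel_size = 2
    # One cache engine per virtual engine, as in _init_cache_engine above.
    cache_engine: List[ToyCacheEngine] = [
        ToyCacheEngine() for _ in range(pipeline_parallel_size)
    ]

    # Dispatch mirrors execute_worker below: index by virtual engine.
    virtual_engine = 1
    cache_engine[virtual_engine].swap_in(42)
    assert cache_engine[0].swapped_in == []
    assert cache_engine[1].swapped_in == [42]
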
@@ -237,12 +243,13 @@ class Worker(LocalOrDistributedWorkerBase):
         return self.parallel_config.tensor_parallel_size > 1
 
     @property
-    def kv_cache(self) -> Optional[List[torch.Tensor]]:
+    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
         return self.gpu_cache
 
     @torch.inference_mode()
     def prepare_worker_input(
             self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
+        virtual_engine = execute_model_req.virtual_engine
         num_seq_groups = len(execute_model_req.seq_group_metadata_list)
-        # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
-        # they contain parameters to launch cudamemcpyasync.
+        # `blocks_to_swap_in` and `blocks_to_swap_out` are CPU tensors;
+        # they contain the parameters for launching cudaMemcpyAsync.
@@ -259,25 +266,27 @@ class Worker(LocalOrDistributedWorkerBase):
                                       device=self.device,
                                       dtype=torch.int64).view(-1, 2)
 
-        return WorkerInput(
-            num_seq_groups=num_seq_groups,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-        )
+        return WorkerInput(num_seq_groups=num_seq_groups,
+                           blocks_to_swap_in=blocks_to_swap_in,
+                           blocks_to_swap_out=blocks_to_swap_out,
+                           blocks_to_copy=blocks_to_copy,
+                           virtual_engine=virtual_engine)
 
     @torch.inference_mode()
     def execute_worker(self, worker_input: WorkerInput) -> None:
+        virtual_engine = worker_input.virtual_engine
         # Issue cache operations.
         if (worker_input.blocks_to_swap_in is not None
                 and worker_input.blocks_to_swap_in.numel() > 0):
-            self.cache_engine.swap_in(worker_input.blocks_to_swap_in)
+            self.cache_engine[virtual_engine].swap_in(
+                worker_input.blocks_to_swap_in)
         if (worker_input.blocks_to_swap_out is not None
                 and worker_input.blocks_to_swap_out.numel() > 0):
-            self.cache_engine.swap_out(worker_input.blocks_to_swap_out)
+            self.cache_engine[virtual_engine].swap_out(
+                worker_input.blocks_to_swap_out)
         if (worker_input.blocks_to_copy is not None
                 and worker_input.blocks_to_copy.numel() > 0):
-            self.cache_engine.copy(worker_input.blocks_to_copy)
+            self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_runner.add_lora(lora_request)

+ 31 - 9
aphrodite/task_handler/worker_base.py

@@ -7,10 +7,11 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
 import torch
 from loguru import logger
 
-from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
+from aphrodite.common.sequence import (ExecuteModelRequest,
+                                       IntermediateTensors, SamplerOutput)
 from aphrodite.common.utils import (enable_trace_function_call_for_thread,
                                     is_hip, update_environment_variables)
-from aphrodite.distributed import broadcast_tensor_dict
+from aphrodite.distributed import broadcast_tensor_dict, get_pp_group
 from aphrodite.lora.request import LoRARequest
 from aphrodite.task_handler.model_runner_base import (ModelRunnerBase,
                                                       ModelRunnerInputBase)
@@ -123,6 +124,7 @@ class WorkerInput:
     blocks_to_swap_in: Optional[torch.Tensor] = None
     blocks_to_swap_out: Optional[torch.Tensor] = None
     blocks_to_copy: Optional[torch.Tensor] = None
+    virtual_engine: int = 0
 
     @classmethod
     def from_broadcasted_tensor_dict(
@@ -138,6 +140,7 @@ class WorkerInput:
             blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
             blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
             blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
+            virtual_engine=tensor_dict["virtual_engine"],
         )
 
     def as_broadcastable_tensor_dict(
@@ -150,6 +153,7 @@ class WorkerInput:
             "blocks_to_swap_in": self.blocks_to_swap_in,
             "blocks_to_swap_out": self.blocks_to_swap_out,
             "blocks_to_copy": self.blocks_to_copy,
+            "virtual_engine": self.virtual_engine,
         }
 
         return tensor_dict
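
virtual_engine rides along with the rest of the worker input in the broadcast dict, so every rank in the group routes cache operations to the same engine. A minimal round-trip check with a toy dataclass (not the real WorkerInput):

    from dataclasses import dataclass
    from typing import Any, Dict

    @dataclass
    class ToyWorkerInput:
        num_seq_groups: int
        virtual_engine: int = 0

        def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
            return {
                "num_seq_groups": self.num_seq_groups,
                "virtual_engine": self.virtual_engine,
            }

        @classmethod
        def from_broadcasted_tensor_dict(
                cls, d: Dict[str, Any]) -> "ToyWorkerInput":
            return cls(
                num_seq_groups=d.pop("num_seq_groups"),
                virtual_engine=d["virtual_engine"],
            )

    src = ToyWorkerInput(num_seq_groups=3, virtual_engine=1)
    dst = ToyWorkerInput.from_broadcasted_tensor_dict(
        src.as_broadcastable_tensor_dict())
    assert dst == src
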
@@ -180,11 +184,13 @@ class LocalOrDistributedWorkerBase(WorkerBase):
 
     @property
     @abstractmethod
-    def kv_cache(self) -> Optional[List[torch.Tensor]]:
+    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
         """
-        Get the kv cache to pass to the worker's model runner. Used by the
-        default `execute_model`. If the worker's model runner does not follow
-        the ModelRunnerBase interface, then inherit from WorkerBase instead.
+        Gets the list of kv caches to pass to the worker's model runner. Each
+        element in the list is a kv cache corresponding to a particular virtual
+        engine (PP stream). Used by the default `execute_model`. If the worker's
+        model runner does not follow the ModelRunnerBase interface, then inherit
+        from WorkerBase instead.
         """
         raise NotImplementedError
 
@@ -225,7 +231,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
                 execute_model_req=execute_model_req)
             model_input: ModelRunnerInputBase = (
                 self.model_runner.prepare_model_input(
-                    execute_model_req.seq_group_metadata_list))
+                    execute_model_req.seq_group_metadata_list,
+                    execute_model_req.virtual_engine))
             num_steps = execute_model_req.num_steps
 
             if self.do_metadata_broadcast:
@@ -251,8 +258,23 @@ class LocalOrDistributedWorkerBase(WorkerBase):
         if worker_input.num_seq_groups == 0:
             return []
 
-        return self.model_runner.execute_model(model_input, self.kv_cache,
-                                               num_steps)
+        intermediate_tensors = None
+        if not get_pp_group().is_first_rank:
+            intermediate_tensors = IntermediateTensors(
+                get_pp_group().recv_tensor_dict())
+
+        output = self.model_runner.execute_model(
+            model_input, self.kv_cache[worker_input.virtual_engine]
+            if self.kv_cache is not None else None, intermediate_tensors,
+            num_steps)
+
+        if not get_pp_group().is_last_rank:
+            get_pp_group().send_tensor_dict(output.tensors)
+            return [None]
+
+        # Worker only supports single-step execution; the model runner
+        # already returns its outputs as a list for that single step.
+        return output
 
 
 class WorkerWrapperBase:
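
The default execute_model above splices each worker into a relay: non-first ranks receive the previous stage's tensors, non-last ranks send theirs downstream and report a placeholder, and only the last stage produces real sampler output. A toy in-process relay over plain dicts, standing in for the PP group's send_tensor_dict/recv_tensor_dict:

    import torch

    num_stages = 3
    channels = [dict() for _ in range(num_stages - 1)]  # stage i -> stage i + 1

    def run_stage(rank: int, token_ids: torch.Tensor) -> list:
        is_first, is_last = rank == 0, rank == num_stages - 1
        # Non-first ranks consume the previous stage's intermediate tensors.
        hidden = (token_ids.float() if is_first
                  else channels[rank - 1]["hidden_states"])
        hidden = hidden + 1  # stands in for this stage's share of the layers
        if not is_last:
            channels[rank]["hidden_states"] = hidden  # send downstream
            return [None]  # placeholder, as in the code above
        return [f"sampled from {hidden.tolist()}"]

    outs = [run_stage(r, torch.tensor([1, 2])) for r in range(num_stages)]
    assert outs[0] == outs[1] == [None]
    assert "sampled" in outs[2][0]
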

+ 4 - 2
aphrodite/task_handler/xpu_model_runner.py

@@ -10,8 +10,8 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      SchedulerConfig, VisionLanguageConfig)
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.common.sequence import (SamplerOutput, SequenceData,
-                                       SequenceGroupMetadata)
+from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
+                                       SequenceData, SequenceGroupMetadata)
 from aphrodite.common.utils import CudaMemoryProfiler, make_tensor_with_pad
 from aphrodite.distributed import broadcast_tensor_dict
 from aphrodite.modeling.model_loader import get_model
@@ -190,6 +190,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
     def prepare_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
+        virtual_engine: int = 0,
     ) -> ModelInputForXPU:
         multi_modal_input = None
         if self.is_driver_worker:
@@ -334,6 +335,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
         self,
         model_input: ModelInputForXPU,
         kv_caches: List[torch.Tensor],
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
     ) -> Optional[List[SamplerOutput]]:
         if num_steps > 1:

+ 2 - 2
aphrodite/task_handler/xpu_worker.py

@@ -83,8 +83,8 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
         )
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
-        self.cache_engine: CacheEngine
-        self.gpu_cache: List[torch.Tensor]
+        self.cache_engine: List[CacheEngine]
+        self.gpu_cache: Optional[List[List[torch.Tensor]]]
 
     def init_device(self) -> None:
         if self.device_config.device.type == "xpu" and is_xpu():