Procházet zdrojové kódy

refactor: eliminate parallel worker per-step task scheduling overhead

AlpinDale před 7 měsíci
rodič
revize
de62ceb18c

+ 8 - 0
aphrodite/engine/aphrodite_engine.py

@@ -637,6 +637,14 @@ class AphroditeEngine:
         # Log stats.
         self.do_log_stats(scheduler_outputs, output)
 
+        if not request_outputs:
+            # Stop the execute model loop in parallel workers until there are
+            # more requests to process. This avoids waiting indefinitely in
+            # torch.distributed ops which may otherwise timeout, and unblocks
+            # the RPC thread in the workers so that they can process any other
+            # queued control plane messages, such as add/remove lora adapters.
+            self.model_executor.stop_remote_worker_execution_loop()
+
         return request_outputs
 
     def do_log_stats(

+ 8 - 0
aphrodite/engine/async_aphrodite.py

@@ -238,6 +238,14 @@ class _AsyncAphrodite(AphroditeEngine):
         # Log stats.
         self.do_log_stats(scheduler_outputs, output)
 
+        if not request_outputs:
+            # Stop the execute model loop in parallel workers until there are
+            # more requests to process. This avoids waiting indefinitely in
+            # torch.distributed ops which may otherwise timeout, and unblocks
+            # the RPC thread in the workers so that they can process any other
+            # queued control plane messages, such as add/remove lora adapters.
+            await self.model_executor.stop_remote_worker_execution_loop_async()
+
         return request_outputs
 
     async def encode_request_async(

+ 94 - 26
aphrodite/executor/distributed_gpu_executor.py

@@ -1,17 +1,28 @@
+import asyncio
 from abc import abstractmethod
-from typing import Any, Dict, List, Optional, Set, Tuple
+from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
 
 from loguru import logger
 
+from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
 from aphrodite.executor.executor_base import ExecutorAsyncBase
 from aphrodite.executor.gpu_executor import GPUExecutor
 from aphrodite.lora.request import LoRARequest
-from aphrodite.common.sequence import SamplerOutput
 
 
 class DistributedGPUExecutor(GPUExecutor):
     """Abstract superclass of multi-GPU executor implementations."""
 
+    def __init__(self, *args, **kwargs):
+        # This is non-None when the execute model loop is running
+        # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
+        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
+        # Updated by implementations that require additional args to be passed
+        # to the _run_workers execute_model call
+        self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
+
+        super().__init__(*args, **kwargs)
+
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks.
         This invokes `determine_num_available_blocks` on each worker and takes
@@ -52,13 +63,28 @@ class DistributedGPUExecutor(GPUExecutor):
                           num_gpu_blocks=num_gpu_blocks,
                           num_cpu_blocks=num_cpu_blocks)
 
-    def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
-        all_outputs = self._run_workers("execute_model",
-                                        driver_args=args,
-                                        driver_kwargs=kwargs)
+    def execute_model(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        if self.parallel_worker_tasks is None:
+            self.parallel_worker_tasks = self._run_workers(
+                "start_worker_execution_loop",
+                async_run_remote_workers_only=True,
+                **self.extra_execute_model_run_workers_kwargs)
 
         # Only the driver worker returns the sampling results.
-        return all_outputs[0]
+        return self._driver_execute_model(execute_model_req)
+
+    def stop_remote_worker_execution_loop(self) -> None:
+        if self.parallel_worker_tasks is None:
+            return
+
+        self._driver_execute_model()
+        parallel_worker_tasks = self.parallel_worker_tasks
+        self.parallel_worker_tasks = None
+        # Ensure that workers exit model loop cleanly
+        # (this will raise otherwise)
+        self._wait_for_tasks_completion(parallel_worker_tasks)
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
@@ -88,39 +114,81 @@ class DistributedGPUExecutor(GPUExecutor):
                           pattern=pattern,
                           max_size=max_size)
 
+    @abstractmethod
+    def _driver_execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """Run execute_model in the driver worker.
+        Passing None will cause the driver to stop the model execution
+        loop running in each of the remote workers.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def _run_workers(
         self,
         method: str,
         *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
+        async_run_remote_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
         **kwargs,
     ) -> Any:
-        """Runs the given method on all workers."""
+        """Runs the given method on all workers.
+        Args:
+            async_run_remote_workers_only: If True the method will be run only
+                in the remote workers, not the driver worker. It will also be
+                run asynchronously and return a list of futures rather than
+                blocking on the results.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
+        """Wait for futures returned from _run_workers() with
+        async_run_remote_workers_only to complete."""
         raise NotImplementedError
 
 
 class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
 
+    async def execute_model_async(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        if self.parallel_worker_tasks is None:
+            # Start model execution loop running in the parallel workers
+            self.parallel_worker_tasks = asyncio.create_task(
+                self._start_worker_execution_loop())
+
+        # Only the driver worker returns the sampling results.
+        return await self._driver_execute_model_async(execute_model_req)
+
+    async def stop_remote_worker_execution_loop_async(self) -> None:
+        if self.parallel_worker_tasks is None:
+            return
+
+        await self._driver_execute_model_async()
+        parallel_worker_tasks = self.parallel_worker_tasks
+        self.parallel_worker_tasks = None
+        # Ensure that workers exit model loop cleanly
+        # (this will raise otherwise)
+        await parallel_worker_tasks
+
     @abstractmethod
-    async def _run_workers_async(
+    async def _driver_execute_model_async(
         self,
-        method: str,
-        *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """Execute the model asynchronously in the driver worker.
+        Passing None will cause the driver to stop the model execution
+        loop running in each of the remote workers.
+        """
         raise NotImplementedError
 
-    async def execute_model_async(self, *args,
-                                  **kwargs) -> List[SamplerOutput]:
-        all_outputs = await self._run_workers_async("execute_model",
-                                                    driver_args=args,
-                                                    driver_kwargs=kwargs)
-
-        # Only the driver worker returns the sampling results.
-        return all_outputs[0]
+    @abstractmethod
+    async def _start_worker_execution_loop(self):
+        """Run execution loop on all workers. It guarantees all workers run
+        the loop or None of them is running the loop. Loop can be stopped by
+        `stop_remote_worker_execution_loop`.
+        The API is idempotent (guarantee only 1 loop run at any moment)."""
+        raise NotImplementedError

+ 8 - 0
aphrodite/executor/executor_base.py

@@ -75,6 +75,10 @@ class ExecutorBase(ABC):
         """Executes at least one model step on the given sequences."""
         raise NotImplementedError
 
+    def stop_remote_worker_execution_loop(self) -> None:
+        """Releases parallel workers from model loop."""
+        return
+
     @abstractmethod
     def add_lora(self, lora_request: LoRARequest) -> bool:
         raise NotImplementedError
@@ -110,6 +114,10 @@ class ExecutorAsyncBase(ExecutorBase):
         """Executes one model step on the given sequences."""
         raise NotImplementedError
 
+    async def stop_remote_worker_execution_loop_async(self) -> None:
+        """Releases parallel workers from model loop."""
+        return
+
     async def check_health_async(self) -> None:
         """Checks if the executor is healthy. If not, it should raise an
         exception."""

+ 42 - 28
aphrodite/executor/multiproc_gpu_executor.py

@@ -1,8 +1,9 @@
 import asyncio
 import os
 from functools import partial
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, List, Optional
 
+from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
 from aphrodite.common.utils import (get_aphrodite_instance_id,
                                     get_distributed_init_method, get_ip,
                                     get_open_port, make_async)
@@ -67,16 +68,32 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
                                       None)) is not None:
             worker_monitor.close()
 
+    def _driver_execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """Run execute_model in the driver worker.
+        Passing None will cause the driver to stop the model execution
+        loop running in each of the remote workers.
+        """
+        return self.driver_worker.execute_model(
+            execute_model_req=execute_model_req)
+
     def _run_workers(
         self,
         method: str,
         *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
+        async_run_remote_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
         **kwargs,
     ) -> Any:
-        """Runs the given method on all workers."""
+        """Runs the given method on all workers.
+        Args:
+            async_run_remote_workers_only: If True the method will be run only
+                in the remote workers, not the driver worker. It will also be
+                run asynchronously and return a list of futures rather than
+                blocking on the results.
+        """
 
         if max_concurrent_workers:
             raise NotImplementedError(
@@ -88,15 +105,12 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
             for worker in self.workers
         ]
 
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
+        if async_run_remote_workers_only:
+            # Just return futures
+            return worker_outputs
 
-        # Start the driver worker after all the ray workers.
         driver_worker_method = getattr(self.driver_worker, method)
-        driver_worker_output = driver_worker_method(*driver_args,
-                                                    **driver_kwargs)
+        driver_worker_output = driver_worker_method(*args, **kwargs)
 
         # Get the results of the workers.
         return [driver_worker_output
@@ -107,29 +121,29 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
         if not self.worker_monitor.is_alive():
             raise RuntimeError("Worker processes are not running")
 
+    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
+        """Wait for futures returned from _run_workers() with
+        async_run_remote_workers_only to complete."""
+        for result in parallel_worker_tasks:
+            result.get()
+
 
 class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
                                       DistributedGPUExecutorAsync):
 
-    async def _run_workers_async(
-        self,
-        method: str,
-        *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.driver_exec_model = make_async(self.driver_worker.execute_model)
 
-        driver_executor = make_async(getattr(self.driver_worker, method))
+    async def _driver_execute_model_async(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        return await self.driver_exec_model(execute_model_req)
 
-        # Run all the workers asynchronously.
-        coros = [driver_executor(*driver_args, **driver_kwargs)] + [
-            worker.execute_method_async(method, *args, **kwargs)
+    async def _start_worker_execution_loop(self):
+        coros = [
+            worker.execute_method_async("start_worker_execution_loop")
             for worker in self.workers
         ]
 

+ 43 - 44
aphrodite/executor/ray_gpu_executor.py

@@ -43,6 +43,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
         self.forward_dag = None
         if USE_RAY_COMPILED_DAG:
             self.forward_dag = self._compiled_ray_dag()
+            self.extra_execute_model_run_workers_kwargs[
+                "use_ray_compiled_dag"] = True
 
     def _configure_ray_workers_use_nsight(self,
                                           ray_remote_kwargs) -> Dict[str, Any]:
@@ -170,23 +172,22 @@ class RayGPUExecutor(DistributedGPUExecutor):
                           max_concurrent_workers=self.parallel_config.
                           max_parallel_loading_workers)
 
-    def execute_model(
-            self,
-            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
-        all_outputs = self._run_workers(
-            "execute_model",
-            driver_kwargs={"execute_model_req": execute_model_req},
-            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
-
-        # Only the driver worker returns the sampling results.
-        return all_outputs[0]
+    def _driver_execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """Run execute_model in the driver worker.
+        Passing None will cause the driver to stop the model execution
+        loop running in each of the remote workers.
+        """
+        return self.driver_worker.execute_method("execute_model",
+                                                 execute_model_req)
 
     def _run_workers(
         self,
         method: str,
         *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
+        async_run_remote_workers_only: bool = False,
         all_args: Optional[List[Tuple[Any, ...]]] = None,
         all_kwargs: Optional[List[Dict[str, Any]]] = None,
         use_dummy_driver: bool = False,
@@ -197,9 +198,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
         """Runs the given method on all workers. Can be used in the following
         ways:
 
+        - async_run_remote_workers_only: If True the method will be run only
+          in the remote workers, not the driver worker. It will also be
+          run asynchronously and return a list of futures rather than blocking
+          on the results.
         - args/kwargs: All workers share the same args/kwargs
-        - args/kwargs and driver_args/driver_kwargs: Driver worker has
-          different args
         - all_args/all_kwargs: args/kwargs for each worker are specified
           individually
         """
@@ -208,11 +211,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")
 
-        if driver_args is None:
-            driver_args = args if all_args is None else all_args[0]
-        if driver_kwargs is None:
-            driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
-
         count = len(self.workers)
         all_worker_args = repeat(args, count) if all_args is None \
             else islice(all_args, 1, None)
@@ -224,6 +222,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
             # input. TODO: Fix it.
             assert self.forward_dag is not None
             output_channels = self.forward_dag.execute(1)
+            ray_worker_outputs = []
         else:
             # Start the ray workers first.
             ray_worker_outputs = [
@@ -233,6 +232,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
                      ) in zip(self.workers, all_worker_args, all_worker_kwargs)
             ]
 
+        if async_run_remote_workers_only:
+            # Just return futures
+            return ray_worker_outputs
+
+        driver_args = args if all_args is None else all_args[0]
+        driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
+
         # Start the driver worker after all the ray workers.
         if not use_dummy_driver:
             driver_worker_output = self.driver_worker.execute_method(
@@ -258,6 +264,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
 
         return [driver_worker_output] + ray_worker_outputs
 
+    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
+        """Wait for futures returned from _run_workers() with
+        async_run_remote_workers_only to complete."""
+        ray.get(parallel_worker_tasks)
+
     def _compiled_ray_dag(self):
         import pkg_resources
         required_version = "2.9"
@@ -300,30 +311,18 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.driver_executor = make_async(self.driver_worker.execute_method)
+        self.driver_exec_method = make_async(self.driver_worker.execute_method)
 
-    async def _run_workers_async(
+    async def _driver_execute_model_async(
         self,
-        method: str,
-        *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
-        coros = []
-
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
-
-        coros.append(
-            self.driver_executor(method, *driver_args, **driver_kwargs))
-
-        # Run the ray workers asynchronously.
-        for worker in self.workers:
-            coros.append(worker.execute_method.remote(method, *args, **kwargs))
-
-        all_outputs = await asyncio.gather(*coros)
-        return all_outputs
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        return await self.driver_exec_method("execute_model",
+                                             execute_model_req)
+
+    async def _start_worker_execution_loop(self):
+        coros = [
+            worker.execute_method.remote("start_worker_execution_loop")
+            for worker in self.workers
+        ]
+        return await asyncio.gather(*coros)

+ 3 - 1
aphrodite/spec_decode/ngram_worker.py

@@ -47,7 +47,9 @@ class NGramWorker(LoraNotSupportedWorkerBase):
         # NGram don't need gpu sampler
         pass
 
-    def execute_model(self, execute_model_req: ExecuteModelRequest) -> None:
+    def execute_model(
+            self,
+            execute_model_req: Optional[ExecuteModelRequest] = None) -> None:
         """NGram doesn't depend on model execution, just pass this function"""
         pass
 

+ 61 - 62
aphrodite/spec_decode/spec_decode_worker.py

@@ -227,33 +227,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                               num_cpu_blocks=num_cpu_blocks)
 
-    def _broadcast_control_flow_decision(
-            self,
-            execute_model_req: Optional[ExecuteModelRequest] = None,
-            disable_all_speculation: bool = False) -> Tuple[int, bool]:
-        """Broadcast how many lookahead slots are scheduled for this step, and
-        whether all speculation is disabled, to all non-driver workers.
-        This is required as if the number of draft model runs changes
-        dynamically, the non-driver workers won't know unless we perform a
-        communication to inform then.
-        Returns the broadcasted num_lookahead_slots and disable_all_speculation.
-        """
-
-        if self.rank == self._driver_rank:
-            assert execute_model_req is not None
-
-            broadcast_dict = dict(
-                num_lookahead_slots=execute_model_req.num_lookahead_slots,
-                disable_all_speculation=disable_all_speculation,
-            )
-            broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
-        else:
-            assert execute_model_req is None
-            broadcast_dict = broadcast_tensor_dict(src=self._driver_rank)
-
-        return (broadcast_dict["num_lookahead_slots"],
-                broadcast_dict["disable_all_speculation"])
-
     @torch.inference_mode()
     def execute_model(
         self,
@@ -261,39 +234,58 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
     ) -> List[SamplerOutput]:
         """Perform speculative decoding on the input batch.
         """
+        if self.rank != self._driver_rank:
+            self._run_non_driver_rank()
+            return []
 
-        disable_all_speculation = False
-        if self.rank == self._driver_rank:
-            disable_all_speculation = self._should_disable_all_speculation(
-                execute_model_req)
-
-        (num_lookahead_slots,
-         disable_all_speculation) = self._broadcast_control_flow_decision(
-             execute_model_req, disable_all_speculation)
-
-        if self.rank == self._driver_rank:
-            assert execute_model_req is not None
-            assert execute_model_req.seq_group_metadata_list is not None, (
-                "speculative decoding requires non-None seq_group_metadata_list"
-            )
-
-            self._maybe_disable_speculative_tokens(
-                disable_all_speculation,
-                execute_model_req.seq_group_metadata_list)
-
-            # If no spec tokens, call the proposer and scorer workers normally.
-            # Used for prefill.
-            if num_lookahead_slots == 0 or len(
-                    execute_model_req.seq_group_metadata_list) == 0:
-                return self._run_no_spec(execute_model_req,
-                                         skip_proposer=disable_all_speculation)
-
-            return self._run_speculative_decoding_step(execute_model_req,
-                                                       num_lookahead_slots)
-        else:
-            self._run_non_driver_rank(num_lookahead_slots)
+        if execute_model_req is None:
+            # This signals that there's no more requests to process for now.
+            # All workers are running infinite loop with broadcast_tensor_dict,
+            # and it stops the loop when the driver broadcasts an empty input.
+            # Send an empty input to notify all other workers to stop their
+            # execution loop.
+            broadcast_tensor_dict({}, src=0)
             return []
 
+        disable_all_speculation = self._should_disable_all_speculation(
+            execute_model_req)
+        num_lookahead_slots = execute_model_req.num_lookahead_slots
+
+        # Broadcast how many lookahead slots are scheduled for this step, and
+        # whether all speculation is disabled, to all non-driver workers.
+
+        # This is required as if the number of draft model runs changes
+        # dynamically, the non-driver workers won't know unless we perform a
+        # communication to inform then.
+        broadcast_dict = dict(
+            num_lookahead_slots=num_lookahead_slots,
+            disable_all_speculation=disable_all_speculation,
+        )
+        broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
+
+        assert execute_model_req.seq_group_metadata_list is not None, (
+            "speculative decoding requires non-None seq_group_metadata_list")
+
+        self._maybe_disable_speculative_tokens(
+            disable_all_speculation, execute_model_req.seq_group_metadata_list)
+
+        # If no spec tokens, call the proposer and scorer workers normally.
+        # Used for prefill.
+        if num_lookahead_slots == 0 or len(
+                execute_model_req.seq_group_metadata_list) == 0:
+            return self._run_no_spec(execute_model_req,
+                                     skip_proposer=disable_all_speculation)
+
+        return self._run_speculative_decoding_step(execute_model_req,
+                                                   num_lookahead_slots)
+
+    @torch.inference_mode()
+    def start_worker_execution_loop(self) -> None:
+        """Execute model loop to perform speculative decoding
+        in parallel worker."""
+        while self._run_non_driver_rank():
+            pass
+
     def _should_disable_all_speculation(
             self, execute_model_req: ExecuteModelRequest) -> bool:
         # When the batch size is too large, disable speculative decoding
@@ -340,13 +332,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         sampler_output.logprobs = None
         return [sampler_output]
 
-    def _run_non_driver_rank(self, num_lookahead_slots: int) -> None:
+    def _run_non_driver_rank(self) -> bool:
         """Run proposer and verifier model in non-driver workers. This is used
         for both speculation cases (num_lookahead_slots>0) and non-speculation
         cases (e.g. prefill).
+
+        Returns True iff there are remaining sequences to process.
         """
-        # In non-driver workers the input is None
-        execute_model_req = None
+        assert self.rank != self._driver_rank
+
+        data = broadcast_tensor_dict(src=self._driver_rank)
+        if not data:
+            return False
+        num_lookahead_slots = data["num_lookahead_slots"]
 
         # Even if num_lookahead_slots is zero, we want to run the proposer model
         # as it may have KV.
@@ -354,9 +352,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         # We run the proposer once per lookahead slot. In the future we should
         # delegate how many times it runs to the proposer.
         for _ in range(max(num_lookahead_slots, 1)):
-            self.proposer_worker.execute_model(execute_model_req)
+            self.proposer_worker.execute_model()
 
-        self.scorer_worker.execute_model(execute_model_req)
+        self.scorer_worker.execute_model()
+        return True
 
     @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
     def _run_speculative_decoding_step(

+ 3 - 2
aphrodite/task_handler/embedding_model_runner.py

@@ -45,7 +45,7 @@ class EmbeddingModelRunner(ModelRunner):
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
         kv_caches: List[torch.Tensor],
     ) -> Optional[PoolerOutput]:
         (input_tokens, input_positions, attn_metadata, pooling_metadata,
@@ -82,10 +82,11 @@ class EmbeddingModelRunner(ModelRunner):
 
     def prepare_input_tensors(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
                Set[LoRARequest], LoRAMapping, torch.Tensor]:
         if self.is_driver_worker:
+            assert seq_group_metadata_list is not None
             # Prepare input tensors.
             (
                 input_tokens,

+ 3 - 2
aphrodite/task_handler/model_runner.py

@@ -646,10 +646,11 @@ class ModelRunner:
 
     def prepare_input_tensors(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
                Set[LoRARequest], LoRAMapping, torch.Tensor]:
         if self.is_driver_worker:
+            assert seq_group_metadata_list is not None
             # Prepare input tensors.
             (
                 input_tokens,
@@ -713,7 +714,7 @@ class ModelRunner:
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
         kv_caches: List[torch.Tensor],
     ) -> Optional[SamplerOutput]:
         (input_tokens, input_positions, attn_metadata, sampling_metadata,

+ 63 - 38
aphrodite/task_handler/worker.py

@@ -229,48 +229,42 @@ class Worker(WorkerBase):
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None
     ) -> List[Union[SamplerOutput, PoolerOutput]]:
+        if not self.is_driver_worker:
+            self._execute_model_non_driver()
+            return []
 
         if execute_model_req is None:
-            seq_group_metadata_list = None
-        else:
-            seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+            # This signals that there's no more requests to process for now.
+            # All workers are running infinite loop with broadcast_tensor_dict,
+            # and it stops the loop when the driver broadcasts an empty input.
+            # Send an empty input to notify all other workers to stop their
+            # execution loop.
+            broadcast_tensor_dict({}, src=0)
+            return []
 
-        blocks_to_swap_in: torch.Tensor
-        blocks_to_swap_out: torch.Tensor
-        blocks_to_copy: torch.Tensor
-        if self.is_driver_worker:
-            assert seq_group_metadata_list is not None
-            assert execute_model_req is not None
-            num_seq_groups = len(seq_group_metadata_list)
-            # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
-            # they contain parameters to launch cudamemcpyasync.
-            blocks_to_swap_in = torch.tensor(
-                execute_model_req.blocks_to_swap_in,
-                device="cpu",
-                dtype=torch.int64).view(-1, 2)
-            blocks_to_swap_out = torch.tensor(
-                execute_model_req.blocks_to_swap_out,
-                device="cpu",
-                dtype=torch.int64).view(-1, 2)
-            # `blocks_to_copy` is a gpu tensor. The src and tgt of
-            # blocks to copy are in the same device, and `blocks_to_copy`
-            # can be used directly within cuda kernels.
-            blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
-                                          device=self.device,
+        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+        num_seq_groups = len(seq_group_metadata_list)
+        # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
+        # they contain parameters to launch cudamemcpyasync.
+        blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
+                                         device="cpu",
+                                         dtype=torch.int64).view(-1, 2)
+        blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
+                                          device="cpu",
                                           dtype=torch.int64).view(-1, 2)
-            data: Dict[str, Any] = {
-                "num_seq_groups": num_seq_groups,
-                "blocks_to_swap_in": blocks_to_swap_in,
-                "blocks_to_swap_out": blocks_to_swap_out,
-                "blocks_to_copy": blocks_to_copy,
-            }
-            broadcast_tensor_dict(data, src=0)
-        else:
-            data = broadcast_tensor_dict(src=0)
-            num_seq_groups = data["num_seq_groups"]
-            blocks_to_swap_in = data["blocks_to_swap_in"]
-            blocks_to_swap_out = data["blocks_to_swap_out"]
-            blocks_to_copy = data["blocks_to_copy"]
+        # `blocks_to_copy` is a gpu tensor. The src and tgt of
+        # blocks to copy are in the same device, and `blocks_to_copy`
+        # can be used directly within cuda kernels.
+        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
+                                      device=self.device,
+                                      dtype=torch.int64).view(-1, 2)
+        data: Dict[str, Any] = {
+            "num_seq_groups": num_seq_groups,
+            "blocks_to_swap_in": blocks_to_swap_in,
+            "blocks_to_swap_out": blocks_to_swap_out,
+            "blocks_to_copy": blocks_to_copy,
+        }
+        broadcast_tensor_dict(data, src=0)
 
         self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
 
@@ -285,6 +279,37 @@ class Worker(WorkerBase):
         # to conform to interface.
         return [output]
 
+    @torch.inference_mode()
+    def start_worker_execution_loop(self) -> None:
+        """Execute model loop in parallel worker.
+        You can stop the loop by executing a driver worker with an empty output.
+        See `stop_remote_worker_execution_loop` for more details.
+        """
+        while self._execute_model_non_driver():
+            pass
+
+    def _execute_model_non_driver(self) -> bool:
+        """Execute model in parallel worker.
+        Returns True iff there are remaining sequences to process.
+        """
+        assert not self.is_driver_worker
+        data = broadcast_tensor_dict(src=0)
+        if not data:
+            return False
+
+        num_seq_groups = data.get("num_seq_groups", 0)
+        blocks_to_swap_in = data.get("blocks_to_swap_in")
+        blocks_to_swap_out = data.get("blocks_to_swap_out")
+        blocks_to_copy = data.get("blocks_to_copy")
+        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
+
+        # If there is no input, we don't need to execute the model.
+        if num_seq_groups == 0:
+            return False
+
+        self.model_runner.execute_model(None, self.gpu_cache)
+        return True
+
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_runner.add_lora(lora_request)
 

+ 4 - 3
aphrodite/task_handler/worker_base.py

@@ -1,7 +1,7 @@
 import importlib
 import os
 from abc import ABC, abstractmethod
-from typing import Dict, List, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple
 
 from loguru import logger
 
@@ -47,8 +47,9 @@ class WorkerBase(ABC):
 
     @abstractmethod
     def execute_model(
-            self,
-            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
         """Executes at least one model step on the given sequences, unless no
         sequences are provided."""
         raise NotImplementedError