
refactor: eliminate parallel worker per-step task scheduling overhead

AlpinDale, 7 months ago
commit de62ceb18c

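At a glance: instead of scheduling a fresh RPC task on every decoding step, the driver now starts `start_worker_execution_loop` once on each remote worker, feeds it per-step inputs via `broadcast_tensor_dict`, and releases it by broadcasting an empty dict when the engine goes idle. A minimal, self-contained sketch of that protocol, with a thread and a queue standing in for the remote worker and the torch.distributed broadcast (all names here are illustrative, not the project's API):

```python
import queue
import threading

class FakeBroadcast:
    """Stand-in for broadcast_tensor_dict: the driver puts one dict per
    step, the worker loop blocks until it arrives."""
    def __init__(self) -> None:
        self.q: "queue.Queue[dict]" = queue.Queue()

    def send(self, data: dict) -> None:   # driver side
        self.q.put(data)

    def recv(self) -> dict:               # worker side
        return self.q.get()

def worker_execution_loop(bcast: FakeBroadcast) -> None:
    # Mirrors Worker.start_worker_execution_loop: spin until the driver
    # broadcasts an empty dict, the stop sentinel.
    while True:
        data = bcast.recv()
        if not data:
            return
        print("worker: executing step with", data)

bcast = FakeBroadcast()
t = threading.Thread(target=worker_execution_loop, args=(bcast,))
t.start()                                     # start_worker_execution_loop, once

for step in range(3):                         # per-step path: one broadcast,
    bcast.send({"num_seq_groups": step + 1})  # no new task scheduling

bcast.send({})                                # stop_remote_worker_execution_loop
t.join()                                      # worker exits the loop cleanly
```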
aphrodite/engine/aphrodite_engine.py (+8 -0)

@@ -637,6 +637,14 @@ class AphroditeEngine:
         # Log stats.
         self.do_log_stats(scheduler_outputs, output)
 
+        if not request_outputs:
+            # Stop the execute model loop in parallel workers until there are
+            # more requests to process. This avoids waiting indefinitely in
+            # torch.distributed ops which may otherwise time out, and unblocks
+            # the RPC thread in the workers so that they can process any other
+            # queued control plane messages, such as add/remove LoRA adapters.
+            self.model_executor.stop_remote_worker_execution_loop()
+
         return request_outputs
 
     def do_log_stats(

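A hedged usage sketch of the synchronous path (import paths, `EngineArgs` fields, and method names are assumptions based on the engine's public interface, not taken from this diff): once the request queue drains, a further `step()` returns no outputs and releases the remote workers before returning.

```python
from aphrodite import EngineArgs, SamplingParams  # assumed import path
from aphrodite.engine.aphrodite_engine import AphroditeEngine

engine = AphroditeEngine.from_engine_args(
    EngineArgs(model="facebook/opt-125m", tensor_parallel_size=2))
engine.add_request("req-0", "Hello, my name is",
                   SamplingParams(max_tokens=16))

while engine.has_unfinished_requests():
    for out in engine.step():
        if out.finished:
            print(out.outputs[0].text)

# With the queue now empty, this step produces no outputs, so the engine
# calls stop_remote_worker_execution_loop() and the workers' RPC threads
# are free again (e.g. for add_lora control messages).
print(engine.step())  # []
```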
aphrodite/engine/async_aphrodite.py (+8 -0)

@@ -238,6 +238,14 @@ class _AsyncAphrodite(AphroditeEngine):
         # Log stats.
         self.do_log_stats(scheduler_outputs, output)
 
+        if not request_outputs:
+            # Stop the execute model loop in parallel workers until there are
+            # more requests to process. This avoids waiting indefinitely in
+            # torch.distributed ops which may otherwise time out, and unblocks
+            # the RPC thread in the workers so that they can process any other
+            # queued control plane messages, such as add/remove LoRA adapters.
+            await self.model_executor.stop_remote_worker_execution_loop_async()
+
         return request_outputs
 
     async def encode_request_async(

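The async engine mirrors this with an awaited stop so the event loop stays responsive while the workers wind down. A sketch under the assumption that the async wrapper keeps the vLLM-style `generate` interface (class and argument names are assumptions and may differ across versions):

```python
import asyncio

from aphrodite import AsyncEngineArgs, SamplingParams  # assumed import path
from aphrodite.engine.async_aphrodite import AsyncAphrodite  # assumed name

async def main() -> None:
    engine = AsyncAphrodite.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m", tensor_parallel_size=2))
    final = None
    async for out in engine.generate("Hello,",
                                     SamplingParams(max_tokens=8),
                                     request_id="req-0"):
        final = out
    # Once the background engine loop sees an empty step, it awaits
    # stop_remote_worker_execution_loop_async() before going idle.
    print(final.outputs[0].text)

asyncio.run(main())
```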
aphrodite/executor/distributed_gpu_executor.py (+94 -26)

@@ -1,17 +1,28 @@
+import asyncio
 from abc import abstractmethod
-from typing import Any, Dict, List, Optional, Set, Tuple
+from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
 
 from loguru import logger
 
+from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
 from aphrodite.executor.executor_base import ExecutorAsyncBase
 from aphrodite.executor.gpu_executor import GPUExecutor
 from aphrodite.lora.request import LoRARequest
-from aphrodite.common.sequence import SamplerOutput
 
 
 class DistributedGPUExecutor(GPUExecutor):
     """Abstract superclass of multi-GPU executor implementations."""
 
+    def __init__(self, *args, **kwargs):
+        # This is non-None when the execute model loop is running
+        # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
+        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
+        # Updated by implementations that require additional args to be passed
+        # to the _run_workers execute_model call
+        self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
+
+        super().__init__(*args, **kwargs)
+
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks.
         This invokes `determine_num_available_blocks` on each worker and takes
@@ -52,13 +63,28 @@ class DistributedGPUExecutor(GPUExecutor):
                           num_gpu_blocks=num_gpu_blocks,
                           num_cpu_blocks=num_cpu_blocks)
 
-    def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
-        all_outputs = self._run_workers("execute_model",
-                                        driver_args=args,
-                                        driver_kwargs=kwargs)
+    def execute_model(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        if self.parallel_worker_tasks is None:
+            self.parallel_worker_tasks = self._run_workers(
+                "start_worker_execution_loop",
+                async_run_remote_workers_only=True,
+                **self.extra_execute_model_run_workers_kwargs)
 
         # Only the driver worker returns the sampling results.
-        return all_outputs[0]
+        return self._driver_execute_model(execute_model_req)
+
+    def stop_remote_worker_execution_loop(self) -> None:
+        if self.parallel_worker_tasks is None:
+            return
+
+        self._driver_execute_model()
+        parallel_worker_tasks = self.parallel_worker_tasks
+        self.parallel_worker_tasks = None
+        # Ensure that workers exit model loop cleanly
+        # (this will raise otherwise)
+        self._wait_for_tasks_completion(parallel_worker_tasks)
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
@@ -88,39 +114,81 @@ class DistributedGPUExecutor(GPUExecutor):
                           pattern=pattern,
                           max_size=max_size)
 
+    @abstractmethod
+    def _driver_execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """Run execute_model in the driver worker.
+        Passing None will cause the driver to stop the model execution
+        loop running in each of the remote workers.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def _run_workers(
         self,
         method: str,
         *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
+        async_run_remote_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
         **kwargs,
     ) -> Any:
-        """Runs the given method on all workers."""
+        """Runs the given method on all workers.
+        Args:
+            async_run_remote_workers_only: If True the method will be run only
+                in the remote workers, not the driver worker. It will also be
+                run asynchronously and return a list of futures rather than
+                blocking on the results.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
+        """Wait for futures returned from _run_workers() with
+        async_run_remote_workers_only to complete."""
         raise NotImplementedError
 
 
 class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
 
+    async def execute_model_async(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        if self.parallel_worker_tasks is None:
+            # Start model execution loop running in the parallel workers
+            self.parallel_worker_tasks = asyncio.create_task(
+                self._start_worker_execution_loop())
+
+        # Only the driver worker returns the sampling results.
+        return await self._driver_execute_model_async(execute_model_req)
+
+    async def stop_remote_worker_execution_loop_async(self) -> None:
+        if self.parallel_worker_tasks is None:
+            return
+
+        await self._driver_execute_model_async()
+        parallel_worker_tasks = self.parallel_worker_tasks
+        self.parallel_worker_tasks = None
+        # Ensure that workers exit model loop cleanly
+        # (this will raise otherwise)
+        await parallel_worker_tasks
+
     @abstractmethod
-    async def _run_workers_async(
+    async def _driver_execute_model_async(
         self,
-        method: str,
-        *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """Execute the model asynchronously in the driver worker.
+        Passing None will cause the driver to stop the model execution
+        loop running in each of the remote workers.
+        """
         raise NotImplementedError
 
-    async def execute_model_async(self, *args,
-                                  **kwargs) -> List[SamplerOutput]:
-        all_outputs = await self._run_workers_async("execute_model",
-                                                    driver_args=args,
-                                                    driver_kwargs=kwargs)
-
-        # Only the driver worker returns the sampling results.
-        return all_outputs[0]
+    @abstractmethod
+    async def _start_worker_execution_loop(self):
+        """Run the execution loop on all workers: it guarantees that either
+        all workers run the loop or none of them do. The loop can be stopped
+        via `stop_remote_worker_execution_loop`.
+        The API is idempotent (at most one loop runs at any moment)."""
+        raise NotImplementedError
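The contract these abstract methods impose: `_run_workers(..., async_run_remote_workers_only=True)` only kicks off the remote loops and hands back futures, the driver keeps executing locally, and `_wait_for_tasks_completion` is consulted only at stop time, which is where a crashed worker surfaces as an exception. A rough analogue using `concurrent.futures` in place of remote workers (illustrative only):

```python
from concurrent.futures import ThreadPoolExecutor

def remote_loop(worker_id: int, steps: int) -> None:
    # stand-in for one worker's start_worker_execution_loop
    for _ in range(steps):
        pass
    if worker_id == 1:
        raise RuntimeError("worker died mid-loop")

pool = ThreadPoolExecutor(max_workers=2)
# async_run_remote_workers_only=True analogue: get futures back immediately
parallel_worker_tasks = [pool.submit(remote_loop, i, 1000) for i in range(2)]

# ... the driver would keep calling _driver_execute_model per step here ...

# stop path, _wait_for_tasks_completion analogue: failures re-raise now
for task in parallel_worker_tasks:
    try:
        task.result()
    except RuntimeError as exc:
        print("worker failure surfaced at stop time:", exc)
```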

aphrodite/executor/executor_base.py (+8 -0)

@@ -75,6 +75,10 @@ class ExecutorBase(ABC):
         """Executes at least one model step on the given sequences."""
         """Executes at least one model step on the given sequences."""
         raise NotImplementedError
         raise NotImplementedError
 
 
+    def stop_remote_worker_execution_loop(self) -> None:
+        """Releases parallel workers from model loop."""
+        return
+
     @abstractmethod
     def add_lora(self, lora_request: LoRARequest) -> bool:
         raise NotImplementedError
@@ -110,6 +114,10 @@ class ExecutorAsyncBase(ExecutorBase):
         """Executes one model step on the given sequences."""
         """Executes one model step on the given sequences."""
         raise NotImplementedError
         raise NotImplementedError
 
 
+    async def stop_remote_worker_execution_loop_async(self) -> None:
+        """Releases parallel workers from model loop."""
+        return
+
     async def check_health_async(self) -> None:
         """Checks if the executor is healthy. If not, it should raise an
         exception."""

aphrodite/executor/multiproc_gpu_executor.py (+42 -28)

@@ -1,8 +1,9 @@
 import asyncio
 import os
 from functools import partial
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, List, Optional
 
+from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
 from aphrodite.common.utils import (get_aphrodite_instance_id,
                                     get_distributed_init_method, get_ip,
                                     get_open_port, make_async)
@@ -67,16 +68,32 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
                                       None)) is not None:
             worker_monitor.close()
 
+    def _driver_execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """Run execute_model in the driver worker.
+        Passing None will cause the driver to stop the model execution
+        loop running in each of the remote workers.
+        """
+        return self.driver_worker.execute_model(
+            execute_model_req=execute_model_req)
+
     def _run_workers(
         self,
         method: str,
         *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
+        async_run_remote_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
         **kwargs,
     ) -> Any:
-        """Runs the given method on all workers."""
+        """Runs the given method on all workers.
+        Args:
+            async_run_remote_workers_only: If True the method will be run only
+                in the remote workers, not the driver worker. It will also be
+                run asynchronously and return a list of futures rather than
+                blocking on the results.
+        """
 
         if max_concurrent_workers:
             raise NotImplementedError(
@@ -88,15 +105,12 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
             for worker in self.workers
         ]
 
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
+        if async_run_remote_workers_only:
+            # Just return futures
+            return worker_outputs
 
-        # Start the driver worker after all the ray workers.
         driver_worker_method = getattr(self.driver_worker, method)
-        driver_worker_output = driver_worker_method(*driver_args,
-                                                    **driver_kwargs)
+        driver_worker_output = driver_worker_method(*args, **kwargs)
 
         # Get the results of the workers.
         return [driver_worker_output
@@ -107,29 +121,29 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
         if not self.worker_monitor.is_alive():
             raise RuntimeError("Worker processes are not running")
 
+    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
+        """Wait for futures returned from _run_workers() with
+        async_run_remote_workers_only to complete."""
+        for result in parallel_worker_tasks:
+            result.get()
+
 
 class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
                                       DistributedGPUExecutorAsync):
 
-    async def _run_workers_async(
-        self,
-        method: str,
-        *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.driver_exec_model = make_async(self.driver_worker.execute_model)
 
-        driver_executor = make_async(getattr(self.driver_worker, method))
+    async def _driver_execute_model_async(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        return await self.driver_exec_model(execute_model_req)
 
-        # Run all the workers asynchronously.
-        coros = [driver_executor(*driver_args, **driver_kwargs)] + [
-            worker.execute_method_async(method, *args, **kwargs)
+    async def _start_worker_execution_loop(self):
+        coros = [
+            worker.execute_method_async("start_worker_execution_loop")
             for worker in self.workers
         ]
 

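Worth a note: `make_async(self.driver_worker.execute_model)` is what keeps the async engine responsive while the driver blocks inside torch.distributed ops. A sketch of what such a helper plausibly does (a stand-in, not the project's exact implementation):

```python
import asyncio
import functools
import time
from typing import Any, Callable

def make_async_sketch(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap a blocking callable so awaiting it runs the call in the
    default thread pool instead of stalling the event loop."""
    async def _async(*args: Any, **kwargs: Any) -> Any:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, functools.partial(fn, *args, **kwargs))
    return _async

async def main() -> None:
    slow = make_async_sketch(time.sleep)        # pretend this is execute_model
    await asyncio.gather(slow(0.1), slow(0.1))  # overlapped, not serialized

asyncio.run(main())
```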
aphrodite/executor/ray_gpu_executor.py (+43 -44)

@@ -43,6 +43,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
         self.forward_dag = None
         if USE_RAY_COMPILED_DAG:
             self.forward_dag = self._compiled_ray_dag()
+            self.extra_execute_model_run_workers_kwargs[
+                "use_ray_compiled_dag"] = True
 
     def _configure_ray_workers_use_nsight(self,
                                           ray_remote_kwargs) -> Dict[str, Any]:
@@ -170,23 +172,22 @@ class RayGPUExecutor(DistributedGPUExecutor):
                           max_concurrent_workers=self.parallel_config.
                           max_parallel_loading_workers)
 
-    def execute_model(
-            self,
-            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
-        all_outputs = self._run_workers(
-            "execute_model",
-            driver_kwargs={"execute_model_req": execute_model_req},
-            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
-
-        # Only the driver worker returns the sampling results.
-        return all_outputs[0]
+    def _driver_execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """Run execute_model in the driver worker.
+        Passing None will cause the driver to stop the model execution
+        loop running in each of the remote workers.
+        """
+        return self.driver_worker.execute_method("execute_model",
+                                                 execute_model_req)
 
     def _run_workers(
         self,
         method: str,
         *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
+        async_run_remote_workers_only: bool = False,
         all_args: Optional[List[Tuple[Any, ...]]] = None,
         all_kwargs: Optional[List[Dict[str, Any]]] = None,
         use_dummy_driver: bool = False,
@@ -197,9 +198,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
         """Runs the given method on all workers. Can be used in the following
         """Runs the given method on all workers. Can be used in the following
         ways:
         ways:
 
 
+        - async_run_remote_workers_only: If True the method will be run only
+          in the remote workers, not the driver worker. It will also be
+          run asynchronously and return a list of futures rather than blocking
+          on the results.
         - args/kwargs: All workers share the same args/kwargs
-        - args/kwargs and driver_args/driver_kwargs: Driver worker has
-          different args
         - all_args/all_kwargs: args/kwargs for each worker are specified
           individually
         """
@@ -208,11 +211,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")
 
-        if driver_args is None:
-            driver_args = args if all_args is None else all_args[0]
-        if driver_kwargs is None:
-            driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
-
         count = len(self.workers)
         all_worker_args = repeat(args, count) if all_args is None \
             else islice(all_args, 1, None)
@@ -224,6 +222,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
             # input. TODO: Fix it.
             assert self.forward_dag is not None
             output_channels = self.forward_dag.execute(1)
+            ray_worker_outputs = []
         else:
             # Start the ray workers first.
             ray_worker_outputs = [
@@ -233,6 +232,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
                      ) in zip(self.workers, all_worker_args, all_worker_kwargs)
             ]
 
+        if async_run_remote_workers_only:
+            # Just return futures
+            return ray_worker_outputs
+
+        driver_args = args if all_args is None else all_args[0]
+        driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
+
         # Start the driver worker after all the ray workers.
         if not use_dummy_driver:
             driver_worker_output = self.driver_worker.execute_method(
@@ -258,6 +264,11 @@ class RayGPUExecutor(DistributedGPUExecutor):
 
         return [driver_worker_output] + ray_worker_outputs
 
+    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
+        """Wait for futures returned from _run_workers() with
+        async_run_remote_workers_only to complete."""
+        ray.get(parallel_worker_tasks)
+
     def _compiled_ray_dag(self):
         import pkg_resources
         required_version = "2.9"
@@ -300,30 +311,18 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.driver_executor = make_async(self.driver_worker.execute_method)
+        self.driver_exec_method = make_async(self.driver_worker.execute_method)
 
 
-    async def _run_workers_async(
+    async def _driver_execute_model_async(
         self,
-        method: str,
-        *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
-        coros = []
-
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
-
-        coros.append(
-            self.driver_executor(method, *driver_args, **driver_kwargs))
-
-        # Run the ray workers asynchronously.
-        for worker in self.workers:
-            coros.append(worker.execute_method.remote(method, *args, **kwargs))
-
-        all_outputs = await asyncio.gather(*coros)
-        return all_outputs
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        return await self.driver_exec_method("execute_model",
+                                             execute_model_req)
+
+    async def _start_worker_execution_loop(self):
+        coros = [
+            worker.execute_method.remote("start_worker_execution_loop")
+            for worker in self.workers
+        ]
+        return await asyncio.gather(*coros)
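On the Ray side the same split shows up as ObjectRefs: starting the loops returns refs without blocking, and `ray.get(parallel_worker_tasks)` at stop time blocks and re-raises any worker failure. A toy demonstration (assumes Ray is installed; trivial tasks stand in for the worker loops):

```python
import ray

ray.init(num_cpus=2)

@ray.remote
def fake_worker_loop(steps: int) -> int:
    # stand-in for start_worker_execution_loop: returns once its stream
    # of broadcast inputs would be exhausted
    return sum(range(steps))

# async_run_remote_workers_only=True analogue: ObjectRefs, no blocking
parallel_worker_tasks = [fake_worker_loop.remote(1_000) for _ in range(2)]

# _wait_for_tasks_completion analogue: blocks, re-raises worker errors
print(ray.get(parallel_worker_tasks))  # [499500, 499500]
```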

aphrodite/spec_decode/ngram_worker.py (+3 -1)

@@ -47,7 +47,9 @@ class NGramWorker(LoraNotSupportedWorkerBase):
         # NGram don't need gpu sampler
         pass
 
-    def execute_model(self, execute_model_req: ExecuteModelRequest) -> None:
+    def execute_model(
+            self,
+            execute_model_req: Optional[ExecuteModelRequest] = None) -> None:
         """NGram doesn't depend on model execution, just pass this function"""
         """NGram doesn't depend on model execution, just pass this function"""
         pass
         pass
 
 

aphrodite/spec_decode/spec_decode_worker.py (+61 -62)

@@ -227,33 +227,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                               num_cpu_blocks=num_cpu_blocks)
 
-    def _broadcast_control_flow_decision(
-            self,
-            execute_model_req: Optional[ExecuteModelRequest] = None,
-            disable_all_speculation: bool = False) -> Tuple[int, bool]:
-        """Broadcast how many lookahead slots are scheduled for this step, and
-        whether all speculation is disabled, to all non-driver workers.
-        This is required as if the number of draft model runs changes
-        dynamically, the non-driver workers won't know unless we perform a
-        communication to inform then.
-        Returns the broadcasted num_lookahead_slots and disable_all_speculation.
-        """
-
-        if self.rank == self._driver_rank:
-            assert execute_model_req is not None
-
-            broadcast_dict = dict(
-                num_lookahead_slots=execute_model_req.num_lookahead_slots,
-                disable_all_speculation=disable_all_speculation,
-            )
-            broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
-        else:
-            assert execute_model_req is None
-            broadcast_dict = broadcast_tensor_dict(src=self._driver_rank)
-
-        return (broadcast_dict["num_lookahead_slots"],
-                broadcast_dict["disable_all_speculation"])
-
     @torch.inference_mode()
     def execute_model(
         self,
@@ -261,39 +234,58 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
     ) -> List[SamplerOutput]:
         """Perform speculative decoding on the input batch.
         """
+        if self.rank != self._driver_rank:
+            self._run_non_driver_rank()
+            return []
 
-        disable_all_speculation = False
-        if self.rank == self._driver_rank:
-            disable_all_speculation = self._should_disable_all_speculation(
-                execute_model_req)
-
-        (num_lookahead_slots,
-         disable_all_speculation) = self._broadcast_control_flow_decision(
-             execute_model_req, disable_all_speculation)
-
-        if self.rank == self._driver_rank:
-            assert execute_model_req is not None
-            assert execute_model_req.seq_group_metadata_list is not None, (
-                "speculative decoding requires non-None seq_group_metadata_list"
-            )
-
-            self._maybe_disable_speculative_tokens(
-                disable_all_speculation,
-                execute_model_req.seq_group_metadata_list)
-
-            # If no spec tokens, call the proposer and scorer workers normally.
-            # Used for prefill.
-            if num_lookahead_slots == 0 or len(
-                    execute_model_req.seq_group_metadata_list) == 0:
-                return self._run_no_spec(execute_model_req,
-                                         skip_proposer=disable_all_speculation)
-
-            return self._run_speculative_decoding_step(execute_model_req,
-                                                       num_lookahead_slots)
-        else:
-            self._run_non_driver_rank(num_lookahead_slots)
+        if execute_model_req is None:
+            # This signals that there are no more requests to process for now.
+            # All workers are running an infinite loop with broadcast_tensor_dict,
+            # and it stops the loop when the driver broadcasts an empty input.
+            # Send an empty input to notify all other workers to stop their
+            # execution loop.
+            broadcast_tensor_dict({}, src=0)
             return []
             return []
 
+        disable_all_speculation = self._should_disable_all_speculation(
+            execute_model_req)
+        num_lookahead_slots = execute_model_req.num_lookahead_slots
+
+        # Broadcast how many lookahead slots are scheduled for this step, and
+        # whether all speculation is disabled, to all non-driver workers.
+
+        # This is required as if the number of draft model runs changes
+        # dynamically, the non-driver workers won't know unless we perform a
+        # communication to inform them.
+        broadcast_dict = dict(
+            num_lookahead_slots=num_lookahead_slots,
+            disable_all_speculation=disable_all_speculation,
+        )
+        broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
+
+        assert execute_model_req.seq_group_metadata_list is not None, (
+            "speculative decoding requires non-None seq_group_metadata_list")
+
+        self._maybe_disable_speculative_tokens(
+            disable_all_speculation, execute_model_req.seq_group_metadata_list)
+
+        # If no spec tokens, call the proposer and scorer workers normally.
+        # Used for prefill.
+        if num_lookahead_slots == 0 or len(
+                execute_model_req.seq_group_metadata_list) == 0:
+            return self._run_no_spec(execute_model_req,
+                                     skip_proposer=disable_all_speculation)
+
+        return self._run_speculative_decoding_step(execute_model_req,
+                                                   num_lookahead_slots)
+
+    @torch.inference_mode()
+    def start_worker_execution_loop(self) -> None:
+        """Execute model loop to perform speculative decoding
+        in parallel worker."""
+        while self._run_non_driver_rank():
+            pass
+
     def _should_disable_all_speculation(
             self, execute_model_req: ExecuteModelRequest) -> bool:
         # When the batch size is too large, disable speculative decoding
@@ -340,13 +332,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         sampler_output.logprobs = None
         return [sampler_output]
 
-    def _run_non_driver_rank(self, num_lookahead_slots: int) -> None:
+    def _run_non_driver_rank(self) -> bool:
         """Run proposer and verifier model in non-driver workers. This is used
         """Run proposer and verifier model in non-driver workers. This is used
         for both speculation cases (num_lookahead_slots>0) and non-speculation
         for both speculation cases (num_lookahead_slots>0) and non-speculation
         cases (e.g. prefill).
         cases (e.g. prefill).
+
+        Returns True iff there are remaining sequences to process.
         """
         """
-        # In non-driver workers the input is None
-        execute_model_req = None
+        assert self.rank != self._driver_rank
+
+        data = broadcast_tensor_dict(src=self._driver_rank)
+        if not data:
+            return False
+        num_lookahead_slots = data["num_lookahead_slots"]
 
         # Even if num_lookahead_slots is zero, we want to run the proposer model
         # as it may have KV.
@@ -354,9 +352,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         # We run the proposer once per lookahead slot. In the future we should
         # delegate how many times it runs to the proposer.
         for _ in range(max(num_lookahead_slots, 1)):
-            self.proposer_worker.execute_model(execute_model_req)
+            self.proposer_worker.execute_model()
 
-        self.scorer_worker.execute_model(execute_model_req)
+        self.scorer_worker.execute_model()
+        return True
 
     @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
     def _run_speculative_decoding_step(

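Driver and non-driver ranks now rendezvous on a single broadcast per step: the driver sends `{num_lookahead_slots, disable_all_speculation}`, or an empty dict to stop, and each non-driver rank derives its proposer run count from that dict alone. A condensed, self-contained sketch of the two sides (a plain list stands in for the broadcast; the worker callables are stubs):

```python
def driver_step(send, execute_model_req):
    # condensed driver-rank control flow from SpecDecodeWorker.execute_model
    if execute_model_req is None:
        send({})  # empty dict: the stop sentinel for the worker loops
        return
    send({"num_lookahead_slots": execute_model_req["num_lookahead_slots"],
          "disable_all_speculation": False})

def non_driver_step(recv, run_proposer, run_scorer):
    # condensed SpecDecodeWorker._run_non_driver_rank
    data = recv()
    if not data:
        return False  # leave start_worker_execution_loop
    # at least one proposer run even with zero lookahead slots, since the
    # proposer may still have KV to write (e.g. during prefill)
    for _ in range(max(data["num_lookahead_slots"], 1)):
        run_proposer()
    run_scorer()
    return True

channel = []  # toy rendezvous in place of broadcast_tensor_dict
driver_step(channel.append, {"num_lookahead_slots": 2})
driver_step(channel.append, None)
while non_driver_step(lambda: channel.pop(0),
                      run_proposer=lambda: print("proposer step"),
                      run_scorer=lambda: print("scorer step")):
    pass
```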
aphrodite/task_handler/embedding_model_runner.py (+3 -2)

@@ -45,7 +45,7 @@ class EmbeddingModelRunner(ModelRunner):
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
         kv_caches: List[torch.Tensor],
     ) -> Optional[PoolerOutput]:
         (input_tokens, input_positions, attn_metadata, pooling_metadata,
@@ -82,10 +82,11 @@ class EmbeddingModelRunner(ModelRunner):
 
     def prepare_input_tensors(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
                Set[LoRARequest], LoRAMapping, torch.Tensor]:
         if self.is_driver_worker:
+            assert seq_group_metadata_list is not None
             # Prepare input tensors.
             (
                 input_tokens,

aphrodite/task_handler/model_runner.py (+3 -2)

@@ -646,10 +646,11 @@ class ModelRunner:
 
     def prepare_input_tensors(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
                Set[LoRARequest], LoRAMapping, torch.Tensor]:
         if self.is_driver_worker:
+            assert seq_group_metadata_list is not None
             # Prepare input tensors.
             (
                 input_tokens,
@@ -713,7 +714,7 @@ class ModelRunner:
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
         kv_caches: List[torch.Tensor],
     ) -> Optional[SamplerOutput]:
         (input_tokens, input_positions, attn_metadata, sampling_metadata,

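Both runners now take `Optional[List[SequenceGroupMetadata]]` because non-driver ranks call them with `None` and reconstruct their inputs from the broadcast; only the driver branch still requires real metadata, hence the new asserts. Schematically (plain dicts and injected send/recv stand in for the metadata and tensor types):

```python
from typing import Dict, List, Optional

def prepare_input_tensors_sketch(
        is_driver_worker: bool,
        seq_group_metadata_list: Optional[List[dict]],
        send, recv) -> Dict:
    if is_driver_worker:
        # the driver still needs real metadata: this mirrors the new assert
        assert seq_group_metadata_list is not None
        prepared = {"num_seq_groups": len(seq_group_metadata_list)}
        send(prepared)      # broadcast_tensor_dict(..., src=0) analogue
    else:
        prepared = recv()   # non-driver ranks rebuild inputs from the wire
    return prepared
```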
aphrodite/task_handler/worker.py (+63 -38)

@@ -229,48 +229,42 @@ class Worker(WorkerBase):
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None
     ) -> List[Union[SamplerOutput, PoolerOutput]]:
+        if not self.is_driver_worker:
+            self._execute_model_non_driver()
+            return []
 
         if execute_model_req is None:
-            seq_group_metadata_list = None
-        else:
-            seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+            # This signals that there are no more requests to process for now.
+            # All workers are running an infinite loop with broadcast_tensor_dict,
+            # and it stops the loop when the driver broadcasts an empty input.
+            # Send an empty input to notify all other workers to stop their
+            # execution loop.
+            broadcast_tensor_dict({}, src=0)
+            return []
 
-        blocks_to_swap_in: torch.Tensor
-        blocks_to_swap_out: torch.Tensor
-        blocks_to_copy: torch.Tensor
-        if self.is_driver_worker:
-            assert seq_group_metadata_list is not None
-            assert execute_model_req is not None
-            num_seq_groups = len(seq_group_metadata_list)
-            # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
-            # they contain parameters to launch cudamemcpyasync.
-            blocks_to_swap_in = torch.tensor(
-                execute_model_req.blocks_to_swap_in,
-                device="cpu",
-                dtype=torch.int64).view(-1, 2)
-            blocks_to_swap_out = torch.tensor(
-                execute_model_req.blocks_to_swap_out,
-                device="cpu",
-                dtype=torch.int64).view(-1, 2)
-            # `blocks_to_copy` is a gpu tensor. The src and tgt of
-            # blocks to copy are in the same device, and `blocks_to_copy`
-            # can be used directly within cuda kernels.
-            blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
-                                          device=self.device,
+        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+        num_seq_groups = len(seq_group_metadata_list)
+        # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
+        # they contain parameters to launch cudamemcpyasync.
+        blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
+                                         device="cpu",
+                                         dtype=torch.int64).view(-1, 2)
+        blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
+                                          device="cpu",
                                           dtype=torch.int64).view(-1, 2)
                                           dtype=torch.int64).view(-1, 2)
-                "num_seq_groups": num_seq_groups,
-                "blocks_to_swap_in": blocks_to_swap_in,
-                "blocks_to_swap_out": blocks_to_swap_out,
-                "blocks_to_copy": blocks_to_copy,
-            }
-            broadcast_tensor_dict(data, src=0)
-        else:
-            data = broadcast_tensor_dict(src=0)
-            num_seq_groups = data["num_seq_groups"]
-            blocks_to_swap_in = data["blocks_to_swap_in"]
-            blocks_to_swap_out = data["blocks_to_swap_out"]
-            blocks_to_copy = data["blocks_to_copy"]
+        # `blocks_to_copy` is a gpu tensor. The src and tgt of
+        # blocks to copy are in the same device, and `blocks_to_copy`
+        # can be used directly within cuda kernels.
+        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
+                                      device=self.device,
+                                      dtype=torch.int64).view(-1, 2)
+        data: Dict[str, Any] = {
+            "num_seq_groups": num_seq_groups,
+            "blocks_to_swap_in": blocks_to_swap_in,
+            "blocks_to_swap_out": blocks_to_swap_out,
+            "blocks_to_copy": blocks_to_copy,
+        }
+        broadcast_tensor_dict(data, src=0)
 
         self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
 
@@ -285,6 +279,37 @@ class Worker(WorkerBase):
         # to conform to interface.
         return [output]
 
+    @torch.inference_mode()
+    def start_worker_execution_loop(self) -> None:
+        """Execute model loop in parallel worker.
+        The loop exits once the driver broadcasts an empty input.
+        See `stop_remote_worker_execution_loop` for more details.
+        """
+        while self._execute_model_non_driver():
+            pass
+
+    def _execute_model_non_driver(self) -> bool:
+        """Execute model in parallel worker.
+        Returns True iff there are remaining sequences to process.
+        """
+        assert not self.is_driver_worker
+        data = broadcast_tensor_dict(src=0)
+        if not data:
+            return False
+
+        num_seq_groups = data.get("num_seq_groups", 0)
+        blocks_to_swap_in = data.get("blocks_to_swap_in")
+        blocks_to_swap_out = data.get("blocks_to_swap_out")
+        blocks_to_copy = data.get("blocks_to_copy")
+        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
+
+        # If there is no input, we don't need to execute the model.
+        if num_seq_groups == 0:
+            return False
+
+        self.model_runner.execute_model(None, self.gpu_cache)
+        return True
+
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_runner.add_lora(lora_request)
 

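A small detail in the rewritten driver path: packing the swap/copy block pairs with `.view(-1, 2)` also handles the empty case cleanly, which matters because idle steps (no swaps, no copies) are exactly when the stop sentinel gets sent. A quick check (requires torch):

```python
import torch

pairs = [(3, 7), (9, 1)]  # (src_block, dst_block) pairs from the scheduler
t = torch.tensor(pairs, device="cpu", dtype=torch.int64).view(-1, 2)
print(t.shape)            # torch.Size([2, 2])

# An idle step schedules nothing; view(-1, 2) still yields a valid
# zero-row tensor instead of raising.
empty = torch.tensor([], device="cpu", dtype=torch.int64).view(-1, 2)
print(empty.shape)        # torch.Size([0, 2])
```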
aphrodite/task_handler/worker_base.py (+4 -3)

@@ -1,7 +1,7 @@
 import importlib
 import os
 from abc import ABC, abstractmethod
-from typing import Dict, List, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple
 
 from loguru import logger
 
@@ -47,8 +47,9 @@ class WorkerBase(ABC):
 
     @abstractmethod
     def execute_model(
-            self,
-            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
         """Executes at least one model step on the given sequences, unless no
         """Executes at least one model step on the given sequences, unless no
         sequences are provided."""
         sequences are provided."""
         raise NotImplementedError
         raise NotImplementedError