# ray_tpu_executor.py

import asyncio
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple,
                    Union)

from loguru import logger

import aphrodite.common.envs as envs
from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.common.utils import (get_aphrodite_instance_id,
                                    get_distributed_init_method, get_ip,
                                    get_open_port, make_async)
from aphrodite.executor.executor_base import ExecutorAsyncBase
from aphrodite.executor.ray_utils import RayWorkerWrapper, ray
from aphrodite.executor.tpu_executor import TPUExecutor
from aphrodite.modeling.layers.sampler import SamplerOutput

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

APHRODITE_TRACE_FUNCTION = envs.APHRODITE_TRACE_FUNCTION


class RayTPUExecutor(TPUExecutor):

    uses_ray: bool = True

    def __init__(self, *args, **kwargs):
        # This is non-None when the execute model loop is running
        # in the parallel workers. It's a coroutine in the AsyncAphrodite case.
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
        # Updated by implementations that require additional args to be passed
        # to the _run_workers execute_model call
        self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}

        super().__init__(*args, **kwargs)

    def _init_executor(self) -> None:
        assert self.parallel_config.distributed_executor_backend == "ray"
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel TPU workers.
        self._init_workers_ray(placement_group)

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("TPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            assert self.speculative_config is None
            if self.scheduler_config.is_multi_step:
                worker_module_name = "aphrodite.worker.multi_step_tpu_worker"
                worker_class_name = "MultiStepTPUWorker"
            else:
                worker_module_name = "aphrodite.worker.tpu_worker"
                worker_class_name = "TPUWorker"

            # GKE does not fetch environment information from the metadata
            # server and instead sets these variables from within the Ray
            # process. Therefore we need to override the Ray environment
            # variables manually.
            override_env = {}
            if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ:
                override_env.update({
                    "TPU_CHIPS_PER_HOST_BOUNDS":
                    os.environ["TPU_CHIPS_PER_HOST_BOUNDS"]
                })
            if "TPU_HOST_BOUNDS" in os.environ:
                override_env.update(
                    {"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]})

            worker = ray.remote(
                num_cpus=0,
                resources={"TPU": 1},
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(
                worker_module_name=worker_module_name,
                worker_class_name=worker_class_name,
                trust_remote_code=self.model_config.trust_remote_code,
            )
            if override_env:
                worker.override_env_vars.remote(override_env)

            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    worker_module_name=worker_module_name,
                    worker_class_name=worker_class_name,
                    trust_remote_code=self.model_config.trust_remote_code,
                )
            else:
                # Otherwise, add it to the list of remote workers.
                self.workers.append(worker)

        logger.debug(f"workers: {self.workers}")
        logger.debug(f"driver_dummy_worker: {self.driver_dummy_worker}")
        if self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any TPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "TPU node.")

        worker_ips = [
            ray.get(worker.get_node_ip.remote())  # type: ignore[attr-defined]
            for worker in self.workers
        ]
        ip_counts: Dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        def sort_by_driver_then_worker_ip(worker):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (vllm engine),
               it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
               be placed first.
            3. Finally, if the worker is on a node with a smaller IP address,
               it should be placed first.
            """
            ip = ray.get(worker.get_node_ip.remote())
            return (ip != driver_ip, ip_counts[ip], ip)
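
        # Illustrative example (hypothetical IPs): with the driver on
        # 10.0.0.1 hosting one worker, a single worker on 10.0.0.3, and two
        # workers on 10.0.0.2, the sort keys compare as
        #   (False, 1, "10.0.0.1") < (True, 1, "10.0.0.3")
        #                          < (True, 2, "10.0.0.2"),
        # so the driver-node worker is placed first, then the worker on the
        # less populated node, then the remaining workers grouped by node.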

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)

        # Get the set of TPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        node_workers = defaultdict(list)
        for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)

        APHRODITE_INSTANCE_ID = get_aphrodite_instance_id()

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [({
            "APHRODITE_INSTANCE_ID":
            APHRODITE_INSTANCE_ID,
            "APHRODITE_TRACE_FUNCTION":
            str(APHRODITE_TRACE_FUNCTION),
        }, ) for _ in worker_node_and_gpu_ids]
        self._run_workers("update_environment_variables",
                          all_args=all_args_to_update_environment_variables)

        if len(node_workers) == 1:
            # In the single-node case, we don't need to get the IP address;
            # the loopback address is sufficient.
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper.
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

    def _driver_execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        return self.driver_worker.execute_method("execute_model",
                                                 execute_model_req)

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_remote_workers_only: bool = False,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        use_ray_compiled_dag: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers. Can be used in the following
        ways:

        - async_run_remote_workers_only: If True the method will be run only
          in the remote workers, not the driver worker. It will also be
          run asynchronously and return a list of futures rather than blocking
          on the results.
        - args/kwargs: All workers share the same args/kwargs
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually
        """

        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        count = len(self.workers)
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 1, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 1, None)

        # Start the ray workers first.
        ray_worker_outputs = [
            worker.execute_method.remote(method, *worker_args, **worker_kwargs)
            for (worker, worker_args, worker_kwargs
                 ) in zip(self.workers, all_worker_args, all_worker_kwargs)
        ]

        if async_run_remote_workers_only:
            # Just return futures
            return ray_worker_outputs

        driver_args = args if all_args is None else all_args[0]
        driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

        # Start the driver worker after all the ray workers.
        if not use_dummy_driver:
            driver_worker_output = self.driver_worker.execute_method(
                method, *driver_args, **driver_kwargs)
        else:
            assert self.driver_dummy_worker is not None
            driver_worker_output = ray.get(
                self.driver_dummy_worker.execute_method.remote(
                    method, *driver_args, **driver_kwargs))

        # Get the results of the ray workers.
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)

        return [driver_worker_output] + ray_worker_outputs

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        ray.get(parallel_worker_tasks)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        num_blocks = self._run_workers("determine_num_available_blocks", )
        num_tpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)
        return num_tpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        logger.info(f"# TPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks
        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_remote_workers_only=True,
                **self.extra_execute_model_run_workers_kwargs)

        # Only the driver worker returns the sampling results.
        return self._driver_execute_model(execute_model_req)

    def stop_remote_worker_execution_loop(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        self._driver_execute_model()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        self._wait_for_tasks_completion(parallel_worker_tasks)


class RayTPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.driver_exec_method = make_async(self.driver_worker.execute_method)

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            # Start model execution loop running in the parallel workers
            self.parallel_worker_tasks = asyncio.create_task(
                self._start_worker_execution_loop())

        # Only the driver worker returns the sampling results.
        return await self._driver_execute_model_async(execute_model_req)

    async def stop_remote_worker_execution_loop_async(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        await self._driver_execute_model_async()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        await parallel_worker_tasks

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        return await self.driver_exec_method("execute_model",
                                             execute_model_req)

    async def _start_worker_execution_loop(self):
        coros = [
            worker.execute_method.remote("start_worker_execution_loop")
            for worker in self.workers
        ]
        return await asyncio.gather(*coros)
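

# Usage sketch (illustrative only, not executed): the engine selects this
# executor when parallel_config.distributed_executor_backend == "ray" and
# provides a Ray placement group with one "TPU" resource per worker bundle.
# A typical call sequence, mirroring the methods defined above, would be:
#
#   num_tpu_blocks, num_cpu_blocks = executor.determine_num_available_blocks()
#   executor.initialize_cache(num_tpu_blocks, num_cpu_blocks)
#   outputs = executor.execute_model(execute_model_req)  # first call also
#                                                        # starts the remote loop
#   executor.stop_remote_worker_execution_loop()
#
# `executor` and `execute_model_req` above are placeholders for objects
# constructed by the engine, not names defined in this file.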