ray_tpu_executor.py

import asyncio
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple,
                    Union)

from loguru import logger

import aphrodite.common.envs as envs
from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.common.utils import (get_aphrodite_instance_id,
                                    get_distributed_init_method, get_ip,
                                    get_open_port, make_async)
from aphrodite.executor.executor_base import ExecutorAsyncBase
from aphrodite.executor.ray_utils import RayWorkerWrapper, ray
from aphrodite.executor.tpu_executor import TPUExecutor
from aphrodite.modeling.layers.sampler import SamplerOutput

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

APHRODITE_TRACE_FUNCTION = envs.APHRODITE_TRACE_FUNCTION

class RayTPUExecutor(TPUExecutor):
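    """TPU executor that distributes inference over Ray: one worker actor
    per TPU bundle in the placement group, plus a driver worker co-located
    with a resource-holding dummy actor on the driver node."""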

    def __init__(self, *args, **kwargs):
        # This is non-None when the execute model loop is running
        # in the parallel workers. It's a coroutine in the AsyncAphrodite case.
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
        # Updated by implementations that require additional args to be passed
        # to the _run_workers execute_model call
        self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
        super().__init__(*args, **kwargs)

    def _init_executor(self) -> None:
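        """Create the Ray TPU workers from the parallel config's placement
        group (Ray usage stats collection is disabled first)."""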
        assert self.parallel_config.distributed_executor_backend == "ray"
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel TPU workers.
        self._init_workers_ray(placement_group)

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
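        """Create one RayWorkerWrapper actor per TPU bundle in the placement
        group, keep a dummy actor that holds the driver node's TPU resource,
        sort the remaining workers (driver node first), then initialize,
        set up devices, and load the model on all workers."""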
        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual ray actors.
        self.workers: List[RayWorkerWrapper] = []

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("TPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            assert self.speculative_config is None
            worker_module_name = "aphrodite.worker.tpu_worker"
            worker_class_name = "TPUWorker"

            # GKE does not fetch environment information from metadata server
            # and instead sets these from within the Ray process. Therefore we
            # need to override the Ray environment variables manually.
            override_env = {}
            if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ:
                override_env.update({
                    "TPU_CHIPS_PER_HOST_BOUNDS":
                    os.environ["TPU_CHIPS_PER_HOST_BOUNDS"]
                })
            if "TPU_HOST_BOUNDS" in os.environ:
                override_env.update(
                    {"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]})

            worker = ray.remote(
                num_cpus=0,
                resources={"TPU": 1},
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(
                worker_module_name=worker_module_name,
                worker_class_name=worker_class_name,
                trust_remote_code=self.model_config.trust_remote_code,
            )
            if override_env:
                worker.override_env_vars.remote(override_env)

            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    worker_module_name=worker_module_name,
                    worker_class_name=worker_class_name,
                    trust_remote_code=self.model_config.trust_remote_code,
                )
            else:
                # Else, add it to the list of workers.
                self.workers.append(worker)

        logger.debug(f"workers: {self.workers}")
        logger.debug(f"driver_dummy_worker: {self.driver_dummy_worker}")
        if self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any TPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "TPU node.")

        worker_ips = [
            ray.get(worker.get_node_ip.remote())  # type: ignore[attr-defined]
            for worker in self.workers
        ]
        ip_counts: Dict[str, int] = {}
        for ip in worker_ips:
            ip_counts[ip] = ip_counts.get(ip, 0) + 1

        def sort_by_driver_then_worker_ip(worker):
            """
            Sort the workers based on 3 properties:
            1. If the worker is on the same node as the driver (the Aphrodite
               engine), it should be placed first.
            2. Then, if the worker is on a node with fewer workers, it should
               be placed first.
            3. Finally, if the worker is on a node with a smaller IP address,
               it should be placed first.
            """
            ip = ray.get(worker.get_node_ip.remote())
            return (ip != driver_ip, ip_counts[ip], ip)

        # After sorting, the workers on the same node will be
        # close to each other, and the workers on the driver
        # node will be placed first.
        self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)

        # Get the set of TPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        node_workers = defaultdict(list)
        for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)

        APHRODITE_INSTANCE_ID = get_aphrodite_instance_id()

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [({
            "APHRODITE_INSTANCE_ID": APHRODITE_INSTANCE_ID,
            "APHRODITE_TRACE_FUNCTION": str(APHRODITE_TRACE_FUNCTION),
        }, ) for _ in worker_node_and_gpu_ids]
        self._run_workers("update_environment_variables",
                          all_args=all_args_to_update_environment_variables)

        if len(node_workers) == 1:
            # In the single-node case, we don't need to get the IP address;
            # the loopback address is sufficient.
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them,
            # while they might not work for communication inside the node
            # if the network setup is complicated. Using the loopback address
            # solves this issue, as it always works for communication inside
            # the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside worker wrapper.
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

    def _driver_execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        return self.driver_worker.execute_method("execute_model",
                                                 execute_model_req)

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_remote_workers_only: bool = False,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        use_ray_compiled_dag: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers. Can be used in the following
        ways:

        - async_run_remote_workers_only: If True the method will be run only
          in the remote workers, not the driver worker. It will also be
          run asynchronously and return a list of futures rather than blocking
          on the results.
        - args/kwargs: All workers share the same args/kwargs
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually
        """
        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        count = len(self.workers)
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 1, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 1, None)

        # Start the ray workers first.
        ray_worker_outputs = [
            worker.execute_method.remote(method, *worker_args, **worker_kwargs)
            for (worker, worker_args, worker_kwargs
                 ) in zip(self.workers, all_worker_args, all_worker_kwargs)
        ]

        if async_run_remote_workers_only:
            # Just return futures
            return ray_worker_outputs

        driver_args = args if all_args is None else all_args[0]
        driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

        # Start the driver worker after all the ray workers.
        if not use_dummy_driver:
            driver_worker_output = self.driver_worker.execute_method(
                method, *driver_args, **driver_kwargs)
        else:
            assert self.driver_dummy_worker is not None
            driver_worker_output = ray.get(
                self.driver_dummy_worker.execute_method.remote(
                    method, *driver_args, **driver_kwargs))

        # Get the results of the ray workers.
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)

        return [driver_worker_output] + ray_worker_outputs

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        ray.get(parallel_worker_tasks)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
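        """Ask every worker for its available KV cache block counts and
        return the per-device minimum so that all ranks agree."""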
        num_blocks = self._run_workers("determine_num_available_blocks")
        num_tpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)
        return num_tpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
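        """Record the block counts in the cache config and initialize the
        KV cache on the driver and all remote workers."""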
        logger.info(f"# TPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks
        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
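        """Start the busy execution loop on the remote workers (once) and
        run the request on the driver worker, which returns the sampled
        output."""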
        if self.parallel_worker_tasks is None:
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_remote_workers_only=True,
                **self.extra_execute_model_run_workers_kwargs)

        # Only the driver worker returns the sampling results.
        return self._driver_execute_model(execute_model_req)

    def stop_remote_worker_execution_loop(self) -> None:
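        """Ask the remote workers to leave their execution loop (by sending
        a None request through the driver) and wait for them to finish."""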
        if self.parallel_worker_tasks is None:
            return

        self._driver_execute_model()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        self._wait_for_tasks_completion(parallel_worker_tasks)

class RayTPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase):
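    """Async variant of RayTPUExecutor: the driver worker's execute_method
    is wrapped with make_async so the driver call can be awaited."""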

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.driver_exec_method = make_async(self.driver_worker.execute_method)

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
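        """Same as execute_model, but the remote busy loop is started as an
        asyncio task and the driver call is awaited."""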
        if self.parallel_worker_tasks is None:
            # Start model execution loop running in the parallel workers
            self.parallel_worker_tasks = asyncio.create_task(
                self._start_worker_execution_loop())

        # Only the driver worker returns the sampling results.
        return await self._driver_execute_model_async(execute_model_req)

    async def stop_remote_worker_execution_loop_async(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        await self._driver_execute_model_async()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        await parallel_worker_tasks

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        return await self.driver_exec_method("execute_model",
                                             execute_model_req)

    async def _start_worker_execution_loop(self):
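        """Kick off start_worker_execution_loop on every remote worker and
        gather the results."""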
        coros = [
            worker.execute_method.remote("start_worker_execution_loop")
            for worker in self.workers
        ]
        return await asyncio.gather(*coros)