# ray_tpu_executor.py

import asyncio
import os
from collections import defaultdict
from itertools import islice, repeat
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional,
                    Tuple, Union)

from loguru import logger

import aphrodite.common.envs as envs
from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.common.utils import (get_aphrodite_instance_id,
                                    get_distributed_init_method, get_ip,
                                    get_open_port, make_async)
from aphrodite.executor.executor_base import ExecutorAsyncBase
from aphrodite.executor.ray_utils import RayWorkerWrapper, ray
from aphrodite.executor.tpu_executor import TPUExecutor
from aphrodite.modeling.layers.sampler import SamplerOutput

if ray is not None:
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup

APHRODITE_TRACE_FUNCTION = envs.APHRODITE_TRACE_FUNCTION


class RayTPUExecutor(TPUExecutor):

    def __init__(self, *args, **kwargs):
        # This is non-None when the execute model loop is running
        # in the parallel workers. It's a coroutine in the AsyncAphrodite case.
        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
        # Updated by implementations that require additional args to be passed
        # to the _run_workers execute_model call.
        self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}

        super().__init__(*args, **kwargs)

    def _init_executor(self) -> None:
        assert self.parallel_config.distributed_executor_backend == "ray"
        placement_group = self.parallel_config.placement_group

        # Disable Ray usage stats collection.
        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
        if ray_usage != "1":
            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

        # Create the parallel TPU workers.
        self._init_workers_ray(placement_group)

    def _init_workers_ray(self, placement_group: "PlacementGroup",
                          **ray_remote_kwargs):
        # The driver dummy worker does not actually use any resources.
        # It holds the resource for the driver worker.
        self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
        # The remaining workers are the actual Ray actors.
        self.workers: List[RayWorkerWrapper] = []

        # Create the workers.
        driver_ip = get_ip()
        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
            if not bundle.get("TPU", 0):
                continue
            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=placement_group,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            assert self.speculative_config is None
            worker_module_name = "aphrodite.task_handler.tpu_worker"
            worker_class_name = "TPUWorker"

            # GKE does not fetch environment information from the metadata
            # server and instead sets these from within the Ray process.
            # Therefore we need to override the Ray environment variables
            # manually.
            override_env = {}
            if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ:
                override_env.update({
                    "TPU_CHIPS_PER_HOST_BOUNDS":
                    os.environ["TPU_CHIPS_PER_HOST_BOUNDS"]
                })
            if "TPU_HOST_BOUNDS" in os.environ:
                override_env.update(
                    {"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]})
            worker = ray.remote(
                num_cpus=0,
                resources={"TPU": 1},
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(
                worker_module_name=worker_module_name,
                worker_class_name=worker_class_name,
                trust_remote_code=self.model_config.trust_remote_code,
            )
            if override_env:
                worker.override_env_vars.remote(override_env)

            worker_ip = ray.get(worker.get_node_ip.remote())
            if worker_ip == driver_ip and self.driver_dummy_worker is None:
                # If the worker is on the same node as the driver, we use it
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    worker_module_name=worker_module_name,
                    worker_class_name=worker_class_name,
                    trust_remote_code=self.model_config.trust_remote_code,
                )
            else:
                # Otherwise, add it to the list of remote workers.
                self.workers.append(worker)

        if self.driver_dummy_worker is None:
            raise ValueError(
                "Ray does not allocate any TPUs on the driver node. Consider "
                "adjusting the Ray placement group or running the driver on a "
                "TPU node.")

        # Get the set of TPU IDs used on each node.
        worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                    use_dummy_driver=True)

        node_workers = defaultdict(list)
        for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
            node_workers[node_id].append(i)

        APHRODITE_INSTANCE_ID = get_aphrodite_instance_id()

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [({
            "APHRODITE_INSTANCE_ID":
            APHRODITE_INSTANCE_ID,
            "APHRODITE_TRACE_FUNCTION":
            str(APHRODITE_TRACE_FUNCTION),
        }, ) for _ in worker_node_and_gpu_ids]
        self._run_workers("update_environment_variables",
                          all_args=all_args_to_update_environment_variables)

        if len(node_workers) == 1:
            # In the single-node case we don't need the real IP address: the
            # loopback address is sufficient.
            # NOTE: a node may have several IP addresses, one for each
            # network interface. `get_ip()` might return any of them, and they
            # might not work for communication inside the node if the network
            # setup is complicated. Using the loopback address avoids this, as
            # it always works for communication inside the node.
            driver_ip = "127.0.0.1"
        distributed_init_method = get_distributed_init_method(
            driver_ip, get_open_port())

        # Initialize the actual workers inside the worker wrapper.
        init_worker_all_kwargs = [
            self._get_worker_kwargs(
                local_rank=node_workers[node_id].index(rank),
                rank=rank,
                distributed_init_method=distributed_init_method,
            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
        ]
        self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

        self._run_workers("init_device")
        self._run_workers("load_model",
                          max_concurrent_workers=self.parallel_config.
                          max_parallel_loading_workers)

    def _driver_execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        return self.driver_worker.execute_method("execute_model",
                                                 execute_model_req)

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_remote_workers_only: bool = False,
        all_args: Optional[List[Tuple[Any, ...]]] = None,
        all_kwargs: Optional[List[Dict[str, Any]]] = None,
        use_dummy_driver: bool = False,
        max_concurrent_workers: Optional[int] = None,
        use_ray_compiled_dag: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers. Can be used in the following
        ways:

        - async_run_remote_workers_only: If True, the method is run only on
          the remote workers, not the driver worker. It is also run
          asynchronously and returns a list of futures rather than blocking
          on the results.
        - args/kwargs: All workers share the same args/kwargs.
        - all_args/all_kwargs: args/kwargs for each worker are specified
          individually.
        """
        if max_concurrent_workers:
            raise NotImplementedError(
                "max_concurrent_workers is not supported yet.")

        count = len(self.workers)
        all_worker_args = repeat(args, count) if all_args is None \
            else islice(all_args, 1, None)
        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
            else islice(all_kwargs, 1, None)

        # Start the ray workers first.
        ray_worker_outputs = [
            worker.execute_method.remote(method, *worker_args, **worker_kwargs)
            for (worker, worker_args, worker_kwargs
                 ) in zip(self.workers, all_worker_args, all_worker_kwargs)
        ]

        if async_run_remote_workers_only:
            # Just return futures
            return ray_worker_outputs

        driver_args = args if all_args is None else all_args[0]
        driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]

        # Start the driver worker after all the ray workers.
        if not use_dummy_driver:
            driver_worker_output = self.driver_worker.execute_method(
                method, *driver_args, **driver_kwargs)
        else:
            assert self.driver_dummy_worker is not None
            driver_worker_output = ray.get(
                self.driver_dummy_worker.execute_method.remote(
                    method, *driver_args, **driver_kwargs))

        # Get the results of the ray workers.
        if self.workers:
            ray_worker_outputs = ray.get(ray_worker_outputs)

        return [driver_worker_output] + ray_worker_outputs

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_remote_workers_only to complete."""
        ray.get(parallel_worker_tasks)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        num_blocks = self._run_workers("determine_num_available_blocks")
        # Take the minimum across all workers so that the chosen cache size
        # fits on every TPU host.
        num_tpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)
        return num_tpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        logger.info(f"# TPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks
        self._run_workers("initialize_cache",
                          num_gpu_blocks=num_gpu_blocks,
                          num_cpu_blocks=num_cpu_blocks)

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",
                async_run_remote_workers_only=True,
                **self.extra_execute_model_run_workers_kwargs)

        # Only the driver worker returns the sampling results.
        return self._driver_execute_model(execute_model_req)
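
    # NOTE: the remote workers stay inside `start_worker_execution_loop` until
    # the driver signals them to stop. As described in `_driver_execute_model`,
    # calling it with `None` is what makes the remote loops exit, which is why
    # `stop_remote_worker_execution_loop` below issues a bare
    # `self._driver_execute_model()` before waiting on the futures.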
    def stop_remote_worker_execution_loop(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        self._driver_execute_model()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        self._wait_for_tasks_completion(parallel_worker_tasks)


class RayTPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.driver_exec_method = make_async(self.driver_worker.execute_method)

    async def execute_model_async(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        if self.parallel_worker_tasks is None:
            # Start model execution loop running in the parallel workers
            self.parallel_worker_tasks = asyncio.create_task(
                self._start_worker_execution_loop())

        # Only the driver worker returns the sampling results.
        return await self._driver_execute_model_async(execute_model_req)

    async def stop_remote_worker_execution_loop_async(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        await self._driver_execute_model_async()
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly
        # (this will raise otherwise)
        await parallel_worker_tasks

    async def _driver_execute_model_async(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        return await self.driver_exec_method("execute_model",
                                             execute_model_req)

    async def _start_worker_execution_loop(self):
        coros = [
            worker.execute_method.remote("start_worker_execution_loop")
            for worker in self.workers
        ]
        return await asyncio.gather(*coros)
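

# Illustrative usage sketch (not part of the original executor code): a
# minimal example of how an engine might drive the async executor, assuming a
# fully constructed `RayTPUExecutorAsync` and an `ExecuteModelRequest`; both
# require a complete Aphrodite engine configuration to build, so the names
# below are placeholders.
async def _example_async_step(
        executor: RayTPUExecutorAsync,
        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
    # The first call lazily starts `start_worker_execution_loop` on the
    # remote workers; only the driver returns the sampler outputs.
    outputs = await executor.execute_model_async(execute_model_req)
    # Once the request stream drains, stop the remote loops so the gathered
    # futures in `parallel_worker_tasks` complete cleanly.
    await executor.stop_remote_worker_execution_loop_async()
    return outputs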