ray_utils.py 13 KB


  1. import os
  2. import time
  3. from collections import defaultdict
  4. from typing import Dict, List, Optional, Tuple, Union
  5. import msgspec
  6. from loguru import logger
  7. from aphrodite.common.config import ParallelConfig
  8. from aphrodite.common.sequence import ExecuteModelRequest, IntermediateTensors
  9. from aphrodite.common.utils import get_ip, is_hip, is_xpu
  10. from aphrodite.executor.msgspec_utils import decode_hook, encode_hook
  11. from aphrodite.platforms import current_platform
  12. from aphrodite.task_handler.worker_base import WorkerWrapperBase
  # Upper bound, in seconds, on how long the placement-group wait loops in this
  # module (_wait_until_pg_ready / _wait_until_pg_removed) keep polling.
  PG_WAIT_TIMEOUT = 1800
  # Ray is an optional dependency: when it cannot be imported, `ray` and
  # `RayWorkerWrapper` are set to None and the ImportError is kept in
  # `ray_import_err` so it can be chained into a later, clearer error.
  try:
      import ray
      from ray._private.state import available_resources_per_node
      from ray.util import placement_group_table
      from ray.util.placement_group import PlacementGroup

      class RayWorkerWrapper(WorkerWrapperBase):
          """Ray wrapper for aphrodite.task_handler.Worker, allowing Worker to be
          lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""

          def __init__(self, *args, **kwargs) -> None:
              super().__init__(*args, **kwargs)
              # The compiled DAG runs its main execution in a different
              # thread, and that thread has to call cuda.set_device itself.
              # This flag records whether set_device has already been called
              # on that thread.
              self.compiled_dag_cuda_device_set = False

              # msgspec codec used by execute_model_spmd to deserialize
              # requests and serialize outputs.
              self.input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
                                                           dec_hook=decode_hook)
              self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)

          def get_node_ip(self) -> str:
              """Return the IP address reported by get_ip() for this process."""
              return get_ip()

          def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
              """Return the Ray node id and the GPU ids Ray assigned here."""
              node_id = ray.get_runtime_context().get_node_id()
              gpu_ids = ray.get_gpu_ids()
              return node_id, gpu_ids

          def execute_model_spmd(
              self, req_or_tuple: Union[bytes,
                                        Tuple[bytes,
                                              Optional[IntermediateTensors]]]
          ) -> Union[bytes, Tuple[bytes, IntermediateTensors]]:
              """Execute model in SPMD fashion: used only when SPMD worker and
              compiled DAG are both enabled.

              Args:
                  req_or_tuple: A request or a tuple containing the
                      request and intermediate tensors. Intermediate tensors
                      are None unless they are provided because this is a
                      > 0 pipeline stage. The request is serialized by msgspec.

              Returns:
                  The msgspec-encoded output as bytes, or, when the worker
                  produced IntermediateTensors, a tuple of the original
                  serialized request and those tensors so both can be
                  forwarded to the next pipeline stage.
              """
              if isinstance(req_or_tuple, bytes):
                  serialized_req, intermediate_tensors = req_or_tuple, None
              else:
                  serialized_req, intermediate_tensors = req_or_tuple

              execute_model_req = self.input_decoder.decode(serialized_req)

              # TODO: This is needed right now because Ray DAG executes
              # on a background thread, so we need to reset torch's current
              # device.
              import torch
              if not self.compiled_dag_cuda_device_set:
                  torch.cuda.set_device(self.worker.device)
                  self.compiled_dag_cuda_device_set = True

              output = self.worker._execute_model_spmd(execute_model_req,
                                                       intermediate_tensors)
              # Pipeline model request and output to the next pipeline stage
              if isinstance(output, IntermediateTensors):
                  output = serialized_req, output
              else:
                  output = self.output_encoder.encode(output)

              return output

          def override_env_vars(self, vars: Dict[str, str]):
              """Merge the given variables into this process's os.environ."""
              os.environ.update(vars)

      ray_import_err = None

  except ImportError as e:
      ray = None  # type: ignore
      ray_import_err = e
      RayWorkerWrapper = None  # type: ignore
  def ray_is_available() -> bool:
      """Tell whether the module-level `ray` import succeeded.

      Returns:
          True when `ray` is bound to something other than None.
      """
      ray_missing = ray is None
      return not ray_missing
  def assert_ray_available():
      """Fail fast when Ray could not be imported.

      Raises:
          ValueError: If the module-level `ray` is None; the exception is
              chained from the recorded `ray_import_err`.
      """
      if ray is not None:
          # Ray was imported; nothing to report.
          return
      raise ValueError("Failed to import Ray, please install Ray with "
                       "`pip install ray`.") from ray_import_err
  def _verify_bundles(placement_group: "PlacementGroup",
                      parallel_config: ParallelConfig, device_str: str):
      """Verify a given placement group has bundles located in the right place.

      There are 2 rules.
      - Warn if all tensor parallel workers cannot fit in a single node.
      - Fail if driver node is not included in a placement group.

      Args:
          placement_group: The placement group whose bundle layout is checked.
          parallel_config: Supplies `tensor_parallel_size` for the per-node
              capacity check.
          device_str: Device label used in the messages (e.g. "GPU").

      Raises:
          AssertionError: If Ray has not been initialized.
          RuntimeError: If no bundle of the placement group lives on the
              driver's node.
      """
      assert ray.is_initialized(), (
          "Ray is not initialized although distributed-executor-backend is ray.")
      pg_data = placement_group_table(placement_group)
      # bundle_idx -> node_id
      bundle_to_node_ids = pg_data["bundles_to_node_id"]
      # bundle_idx -> bundle (e.g., {"GPU": 1})
      bundles = pg_data["bundles"]
      # node_id -> List of bundle (e.g., {"GPU": 1})
      node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list)

      for bundle_idx, node_id in bundle_to_node_ids.items():
          node_id_to_bundle[node_id].append(bundles[bundle_idx])

      driver_node_id = ray.get_runtime_context().get_node_id()
      if driver_node_id not in node_id_to_bundle:
          raise RuntimeError(
              f"driver node id {driver_node_id} is not included in a placement "
              f"group {placement_group.id}. Node id -> bundles "
              f"{node_id_to_bundle}. "
              f"You don't have enough {device_str}s available in a current "
              "node. Check `ray status` to see if you have available "
              f"{device_str}s in a node {driver_node_id} before starting an "
              "Aphrodite engine.")

      tp_size = parallel_config.tensor_parallel_size
      # A distinct loop name keeps the bundle_idx -> bundle table above intact.
      for node_id, node_bundles in node_id_to_bundle.items():
          if len(node_bundles) < tp_size:
              logger.warning(
                  f"tensor_parallel_size={tp_size} "
                  f"is bigger than a reserved number of {device_str}s "
                  f"({len(node_bundles)} {device_str}s) in a node {node_id}. "
                  "Tensor parallel workers can be spread out to 2+ nodes which "
                  "can degrade the performance unless you have fast "
                  "interconnect across nodes, like Infiniband. To resolve "
                  "this issue, make sure you have at least "
                  f"{tp_size} {device_str}s available at each node.")
  def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
      """Block until the placement group is ready or the timeout elapses.

      Polls the group's readiness reference. After every unsuccessful poll the
      poll timeout is doubled and a progress message is logged.

      Args:
          current_placement_group: The placement group being created.

      Raises:
          ValueError: If the group is still not ready once polling has stopped.
      """
      placement_group_specs = current_placement_group.bundle_specs
      started_at = time.time()
      pg_ready_ref = current_placement_group.ready()
      poll_timeout = 10

      while time.time() - started_at < PG_WAIT_TIMEOUT:
          finished, _pending = ray.wait([pg_ready_ref], timeout=poll_timeout)
          if finished:
              break

          # Exponential backoff for warning print.
          poll_timeout *= 2
          elapsed_seconds = int(time.time() - started_at)
          logger.info(
              f"Waiting for creating a placement group of specs for "
              f"{elapsed_seconds} seconds. specs={placement_group_specs}. "
              "Check `ray status` to see if you have enough resources.")

      # A zero-timeout get distinguishes "ready" from "polling gave up".
      try:
          ray.get(pg_ready_ref, timeout=0)
      except ray.exceptions.GetTimeoutError:
          raise ValueError(
              "Cannot provide a placement group of "
              f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See "
              "`ray status` to make sure the cluster has enough resources."
          ) from None
  def _wait_until_pg_removed(current_placement_group: "PlacementGroup"):
      """Remove a placement group and wait until no current group is reported.

      Requests removal, then polls `ray.util.get_current_placement_group()`
      until it returns None or PG_WAIT_TIMEOUT seconds have passed. The sleep
      between polls doubles on every iteration. Returns silently on timeout.

      Args:
          current_placement_group: The placement group to remove.
      """
      ray.util.remove_placement_group(current_placement_group)
      s = time.time()
      wait_interval = 10
      while time.time() - s < PG_WAIT_TIMEOUT:
          pg = ray.util.get_current_placement_group()
          if pg is None:
              break

          # Exponential backoff for warning print.
          wait_interval *= 2
          # loguru formats with str.format, not printf-style "%d", so the
          # elapsed time is interpolated here with an f-string.
          logger.info(
              "Waiting for removing a placement group of specs for "
              f"{int(time.time() - s)} seconds.")
          time.sleep(wait_interval)
  def initialize_ray_cluster(
      parallel_config: ParallelConfig,
      ray_address: Optional[str] = None,
  ):
      """Initialize the distributed cluster with Ray.

      It will connect to the Ray cluster and create a placement group
      for the workers, which includes the specification of the resources
      for each distributed worker.

      Args:
          parallel_config: The configurations for parallel execution. Its
              `placement_group` attribute is set by this function unless it
              is already populated.
          ray_address: The address of the Ray cluster. If None, uses
              the default Ray cluster address.

      Raises:
          ValueError: If Ray is not installed, if an enclosing placement group
              has a bundle with more than one device or too few device
              bundles, if the cluster has too few devices, if the current
              node has no device available, or if a new placement group
              cannot be created in time.
      """
      assert_ray_available()

      # Connect to a ray cluster.
      if is_hip() or is_xpu():
          ray.init(address=ray_address,
                   ignore_reinit_error=True,
                   num_gpus=parallel_config.world_size)
      else:
          ray.init(address=ray_address, ignore_reinit_error=True)

      if parallel_config.placement_group:
          # Placement group is already set.
          return

      device_str = "GPU" if not current_platform.is_tpu() else "TPU"
      # Create placement group for worker processes
      current_placement_group = ray.util.get_current_placement_group()
      if current_placement_group:
          # We are in a placement group
          bundles = current_placement_group.bundle_specs
          # Verify that we can use the placement group.
          device_bundles = 0
          for bundle in bundles:
              bundle_devices = bundle.get(device_str, 0)
              if bundle_devices > 1:
                  raise ValueError(
                      "Placement group bundle cannot have more than 1 "
                      f"{device_str}.")
              if bundle_devices:
                  device_bundles += 1
          if parallel_config.world_size > device_bundles:
              raise ValueError(
                  f"The number of required {device_str}s exceeds the total "
                  f"number of available {device_str}s in the placement group. "
                  f"Required number of devices: {parallel_config.world_size}. "
                  f"Total number of devices: {device_bundles}.")
      else:
          num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
          if parallel_config.world_size > num_devices_in_cluster:
              # This branch checked cluster-wide totals, not a placement group.
              raise ValueError(
                  f"The number of required {device_str}s exceeds the total "
                  f"number of available {device_str}s in the cluster.")
          # Create a new placement group
          placement_group_specs: List[Dict[str, float]] = ([{
              device_str: 1.0
          } for _ in range(parallel_config.world_size)])

          # Aphrodite engine is also a worker to execute model with an
          # accelerator so it requires to have the device in a current node.
          # Check if the current node has at least one device.
          current_ip = get_ip()
          current_node_id = ray.get_runtime_context().get_node_id()
          # A node missing from the table is treated as having no resources,
          # so it gets the descriptive error below instead of a bare KeyError.
          current_node_resource = available_resources_per_node().get(
              current_node_id, {})
          if current_node_resource.get(device_str, 0) < 1:
              raise ValueError(
                  f"Current node has no {device_str} available. "
                  f"{current_node_resource=}. Aphrodite engine cannot start "
                  f"without {device_str}. Make sure you have at least 1 "
                  f"{device_str} available in a node {current_node_id=} "
                  f"{current_ip=}.")
          # This way, at least one bundle is required to be created in the
          # current node.
          placement_group_specs[0][f"node:{current_ip}"] = 0.001

          # By default, Ray packs resources as much as possible.
          current_placement_group = ray.util.placement_group(
              placement_group_specs, strategy="PACK")
          _wait_until_pg_ready(current_placement_group)

      assert current_placement_group is not None
      _verify_bundles(current_placement_group, parallel_config, device_str)
      # Set the placement group in the parallel config
      parallel_config.placement_group = current_placement_group
  def get_num_tpu_nodes() -> int:
      """Return the cluster's total TPU count divided by the per-node count.

      The per-node count comes from Ray's TPUAcceleratorManager for the current
      node. The total must be an exact multiple of it.
      """
      from ray._private.accelerators import TPUAcceleratorManager

      total_tpus = int(ray.cluster_resources()["TPU"])
      tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators()
      node_count, leftover = divmod(total_tpus, tpus_per_node)
      assert leftover == 0
      return node_count
  def get_num_nodes_in_placement_group() -> int:
      """Count the distinct nodes hosting bundles of the current placement group.

      Returns:
          The number of unique node ids found for the current placement group
          in Ray's placement group table, or 0 when there is no current
          placement group.
      """
      all_groups = ray.util.placement_group_table()
      active_group = ray.util.get_current_placement_group()
      if not active_group:
          return 0

      distinct_nodes = set()
      for group_id, group_info in all_groups.items():
          if group_id != active_group.id.hex():
              continue
          # Several bundles may share a node; the set keeps each node once.
          distinct_nodes.update(group_info["bundles_to_node_id"].values())
      return len(distinct_nodes)