ray_utils.py

import pickle
from typing import List, Optional, Tuple

from loguru import logger

from aphrodite.common.config import ParallelConfig
from aphrodite.common.utils import get_ip, is_hip
from aphrodite.task_handler.worker_base import WorkerWrapperBase
try:
    import ray

    class RayWorkerWrapper(WorkerWrapperBase):
        """Ray wrapper for aphrodite.task_handler.Worker, allowing Worker to
        be lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""

        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            # The compiled DAG runs its main execution in a different
            # thread, which calls cuda.set_device. This flag indicates
            # whether set_device has been called on that thread.
            self.compiled_dag_cuda_device_set = False

        def get_node_ip(self) -> str:
            return get_ip()

        def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
            node_id = ray.get_runtime_context().get_node_id()
            gpu_ids = ray.get_gpu_ids()
            return node_id, gpu_ids

        def execute_model_compiled_dag_remote(self, ignored):
            """Used only when the compiled DAG is enabled."""
            import torch

            if not self.compiled_dag_cuda_device_set:
                torch.cuda.set_device(self.worker.device)
                self.compiled_dag_cuda_device_set = True

            output = self.worker.execute_model()
            output = pickle.dumps(output)
            return output
except ImportError as e:
    # Ray is optional; fall back to None so non-Ray code paths still import.
    ray = None  # type: ignore
    RayWorkerWrapper = None  # type: ignore
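
# Illustrative sketch (not part of this module): the engine typically turns
# RayWorkerWrapper into a Ray actor and calls its methods remotely. The
# constructor keyword arguments below are assumptions about what
# WorkerWrapperBase expects and may differ in this codebase:
#
#   worker = ray.remote(num_cpus=0, num_gpus=1)(RayWorkerWrapper).remote(
#       worker_module_name="aphrodite.task_handler.worker",  # assumed name
#       worker_class_name="Worker",                          # assumed name
#   )
#   node_id, gpu_ids = ray.get(worker.get_node_and_gpu_ids.remote())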
def initialize_ray_cluster(
    parallel_config: ParallelConfig,
    ray_address: Optional[str] = None,
):
    """Initialize the distributed cluster with Ray.

    This will connect to the Ray cluster and create a placement group
    for the workers, which includes the specification of the resources
    for each distributed worker.

    Args:
        parallel_config: The configurations for parallel execution.
        ray_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address.
    """
    if ray is None:
        raise ImportError(
            "Ray is not installed. Please install Ray to use multi-node "
            "serving. You can install Ray by running "
            "`pip install aphrodite-engine[\"ray\"]`.")

    # Connect to a Ray cluster.
    if is_hip():
        ray.init(address=ray_address,
                 ignore_reinit_error=True,
                 num_gpus=parallel_config.world_size)
    else:
        ray.init(address=ray_address, ignore_reinit_error=True)
    if parallel_config.placement_group:
        # Placement group is already set.
        return

    # Create placement group for worker processes.
    current_placement_group = ray.util.get_current_placement_group()
    if current_placement_group:
        # We are in a placement group.
        bundles = current_placement_group.bundle_specs
        # Verify that we can use the placement group.
        gpu_bundles = 0
        for bundle in bundles:
            bundle_gpus = bundle.get("GPU", 0)
            if bundle_gpus > 1:
                raise ValueError(
                    "Placement group bundle cannot have more than 1 GPU.")
            if bundle_gpus:
                gpu_bundles += 1
        if parallel_config.world_size > gpu_bundles:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the placement group.")
    else:
        num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
        if parallel_config.world_size > num_gpus_in_cluster:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the cluster.")
        # Create a new placement group.
        placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
        current_placement_group = ray.util.placement_group(
            placement_group_specs)
        # Wait until the placement group is ready - this will block until
        # all requested resources are available, and will time out if they
        # cannot be provisioned.
        ray.get(current_placement_group.ready(), timeout=1800)

    # Set the placement group in the parallel config.
    parallel_config.placement_group = current_placement_group
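
# Illustrative sketch (not part of this module): a typical call site. Assuming
# `parallel_config.world_size == 4` and no pre-existing placement group, the
# call below creates a placement group with bundles [{"GPU": 1}] * 4 and
# stores it on `parallel_config.placement_group`. `ray_address="auto"`
# connects to an already-running Ray cluster:
#
#   initialize_ray_cluster(parallel_config, ray_address="auto")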