# ray_tools.py

from typing import Optional, List, Tuple, TYPE_CHECKING

from aphrodite.common.config import ParallelConfig
from aphrodite.common.logger import init_logger
from aphrodite.common.utils import is_hip, set_cuda_visible_devices, get_ip

logger = init_logger(__name__)
try:
    import ray

    class RayWorkerAphrodite:
        """Ray wrapper for aphrodite.task_handler.Worker, allowing Worker to
        be lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""

        def __init__(self, init_cached_hf_modules=False) -> None:
            if init_cached_hf_modules:
                # pylint: disable=import-outside-toplevel
                from transformers.dynamic_module_utils import init_hf_modules
                init_hf_modules()
            self.worker = None

        def init_worker(self, worker_init_fn):
            self.worker = worker_init_fn()

        def __getattr__(self, name):
            return getattr(self.worker, name)

        def execute_method(self, method, *args, **kwargs):
            executor = getattr(self, method)
            return executor(*args, **kwargs)

        def get_node_ip(self) -> str:
            return get_ip()

        def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
            node_id = ray.get_runtime_context().get_node_id()
            gpu_ids = ray.get_gpu_ids()
            return node_id, gpu_ids

        def set_cuda_visible_devices(self, device_ids) -> None:
            set_cuda_visible_devices(device_ids)
except ImportError as e:
    logger.warning(f"Failed to import Ray with {e!r}. "
                   "For distributed inference, please install Ray with "
                   "`pip install ray`.")
    ray = None
    RayWorkerAphrodite = None

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup
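

# NOTE: illustrative sketch only, not part of the original module. It shows
# one way the RayWorkerAphrodite wrapper is typically spawned as a Ray actor
# pinned to a single placement-group bundle; the helper name and its
# arguments are assumptions made for this example.
def _example_spawn_worker(placement_group, bundle_index: int):
    """Sketch: create one RayWorkerAphrodite actor on a placement-group bundle."""
    # Imported lazily so the module still loads when Ray is absent.
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

    actor_cls = ray.remote(
        num_cpus=0,
        num_gpus=1,
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=placement_group,
            placement_group_bundle_index=bundle_index,
            placement_group_capture_child_tasks=True,
        ),
    )(RayWorkerAphrodite)
    # The wrapped Worker itself is built later via init_worker(), after Ray
    # has set CUDA_VISIBLE_DEVICES for the actor's GPU.
    return actor_cls.remote(init_cached_hf_modules=True)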
def initialize_cluster(
    parallel_config: ParallelConfig,
    engine_use_ray: bool = False,
    ray_address: Optional[str] = None,
) -> Optional["PlacementGroup"]:
    """Initialize the distributed cluster, using Ray when required.

    Args:
        parallel_config: The configurations for parallel execution.
        engine_use_ray: Whether to use Ray for the async engine.
        ray_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address.

    Returns:
        The placement group specifying the resources reserved for each
        distributed worker, or None if Ray is not used (single-worker case).
    """
    if parallel_config.worker_use_ray or engine_use_ray:
        if ray is None:
            raise ImportError(
                "Ray is not installed. Please install Ray to use distributed "
                "serving.")
        # Connect to a Ray cluster.
        if is_hip():
            ray.init(address=ray_address,
                     ignore_reinit_error=True,
                     num_gpus=parallel_config.world_size)
        else:
            ray.init(address=ray_address, ignore_reinit_error=True)

    if not parallel_config.worker_use_ray:
        assert parallel_config.world_size == 1, (
            "Ray is required if parallel_config.world_size > 1.")
        return None
    # Create a placement group for the worker processes.
    current_placement_group = ray.util.get_current_placement_group()
    if current_placement_group:
        # We are already inside a placement group.
        bundles = current_placement_group.bundle_specs
        # Verify that we can use the placement group.
        gpu_bundles = 0
        for bundle in bundles:
            bundle_gpus = bundle.get("GPU", 0)
            if bundle_gpus > 1:
                raise ValueError(
                    "Placement group bundle cannot have more than 1 GPU.")
            if bundle_gpus:
                gpu_bundles += 1
        if parallel_config.world_size > gpu_bundles:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the placement group.")
    else:
        num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
        if parallel_config.world_size > num_gpus_in_cluster:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the cluster.")
        # Create a new placement group.
        placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
        current_placement_group = ray.util.placement_group(
            placement_group_specs)
        # Wait until the placement group is ready - this blocks until all
        # requested resources are available, and times out if they cannot
        # be provisioned.
        ray.get(current_placement_group.ready(), timeout=1800)

    return current_placement_group
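

# NOTE: illustrative end-to-end sketch only, not part of the original module.
# It assumes a ParallelConfig built elsewhere with worker_use_ray=True, and
# relies on the hypothetical _example_spawn_worker helper sketched above.
def _example_distributed_setup(parallel_config: ParallelConfig):
    """Sketch: spawn one wrapper actor per bundle and query node/GPU ids."""
    placement_group = initialize_cluster(parallel_config)
    workers = [
        _example_spawn_worker(placement_group, bundle_index=i)
        for i in range(parallel_config.world_size)
    ]
    # execute_method resolves names on the wrapper first (as with the call
    # below) and otherwise forwards to the lazily created Worker instance
    # through __getattr__.
    return ray.get(
        [w.execute_method.remote("get_node_and_gpu_ids") for w in workers])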