ray_utils.py

import pickle
from typing import List, Optional, Tuple

from loguru import logger

from aphrodite.common.config import ParallelConfig
from aphrodite.common.utils import get_ip, is_hip
from aphrodite.task_handler.worker_base import WorkerWrapperBase

try:
    import ray

    class RayWorkerWrapper(WorkerWrapperBase):
        """Ray wrapper for aphrodite.task_handler.Worker, allowing Worker to
        be lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""

        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            # The compiled DAG runs its main execution in a separate thread,
            # which must call cuda.set_device itself. This flag records
            # whether set_device has already been called on that thread.
            self.compiled_dag_cuda_device_set = False

        def get_node_ip(self) -> str:
            return get_ip()

        def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
            node_id = ray.get_runtime_context().get_node_id()
            gpu_ids = ray.get_gpu_ids()
            return node_id, gpu_ids

        def execute_model_compiled_dag_remote(self, ignored):
            """Used only when the compiled DAG is enabled."""
            import torch

            if not self.compiled_dag_cuda_device_set:
                torch.cuda.set_device(self.worker.device)
                self.compiled_dag_cuda_device_set = True

            output = self.worker.execute_model()
            output = pickle.dumps(output)
            return output

except ImportError as e:
    logger.warning(f"Failed to import Ray with {e!r}. "
                   "For distributed inference, please install Ray with "
                   "`pip install ray`.")
    ray = None  # type: ignore
    RayWorkerWrapper = None  # type: ignore
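
# Illustrative usage sketch (not part of the original module): one way a
# caller might launch RayWorkerWrapper as a single-GPU Ray actor and query
# its node and GPU placement. The worker_module_name/worker_class_name
# keyword arguments are an assumption about WorkerWrapperBase's constructor
# and may differ in practice.
#
#     worker = ray.remote(num_cpus=0, num_gpus=1)(RayWorkerWrapper).remote(
#         worker_module_name="aphrodite.task_handler.worker",
#         worker_class_name="Worker")
#     node_id, gpu_ids = ray.get(worker.get_node_and_gpu_ids.remote())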


def initialize_ray_cluster(
    parallel_config: ParallelConfig,
    ray_address: Optional[str] = None,
):
    """Initialize the distributed cluster with Ray.

    This connects to the Ray cluster and creates a placement group for the
    workers, which includes the specification of the resources for each
    distributed worker.

    Args:
        parallel_config: The configurations for parallel execution.
        ray_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address.
    """
    if ray is None:
        raise ImportError(
            "Ray is not installed. Please install Ray to use distributed "
            "serving.")

    # Connect to a ray cluster.
    if is_hip():
        ray.init(address=ray_address,
                 ignore_reinit_error=True,
                 num_gpus=parallel_config.world_size)
    else:
        ray.init(address=ray_address, ignore_reinit_error=True)

    if parallel_config.placement_group:
        # Placement group is already set.
        return

    # Create placement group for worker processes.
    current_placement_group = ray.util.get_current_placement_group()
    if current_placement_group:
        # We are in a placement group.
        bundles = current_placement_group.bundle_specs
        # Verify that we can use the placement group.
        gpu_bundles = 0
        for bundle in bundles:
            bundle_gpus = bundle.get("GPU", 0)
            if bundle_gpus > 1:
                raise ValueError(
                    "Placement group bundle cannot have more than 1 GPU.")
            if bundle_gpus:
                gpu_bundles += 1
        if parallel_config.world_size > gpu_bundles:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the placement group.")
    else:
        num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
        if parallel_config.world_size > num_gpus_in_cluster:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the cluster.")
        # Create a new placement group.
        placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
        current_placement_group = ray.util.placement_group(
            placement_group_specs)
        # Wait until the PG is ready - this will block until all requested
        # resources are available, and will timeout if they cannot be
        # provisioned.
        ray.get(current_placement_group.ready(), timeout=1800)

    # Set the placement group in the parallel config.
    parallel_config.placement_group = current_placement_group
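

# Illustrative usage sketch (not part of the original module): initializing
# Ray for a 2-way tensor-parallel run. The ParallelConfig constructor
# arguments shown here are assumptions; initialize_ray_cluster itself only
# relies on `world_size` and `placement_group`.
#
#     from aphrodite.common.config import ParallelConfig
#
#     parallel_config = ParallelConfig(pipeline_parallel_size=1,
#                                      tensor_parallel_size=2,
#                                      worker_use_ray=True)
#     initialize_ray_cluster(parallel_config)
#     # parallel_config.placement_group now holds a placement group with one
#     # {"GPU": 1} bundle per distributed worker.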