ray_tools.py

import pickle
from typing import Optional, List, Tuple

from loguru import logger

from aphrodite.common.config import ParallelConfig
from aphrodite.common.utils import is_hip, set_cuda_visible_devices, get_ip

try:
    import ray

    class RayWorkerAphrodite:
        """Ray wrapper for aphrodite.task_handler.Worker, allowing Worker to
        be lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""

        def __init__(self, init_cached_hf_modules=False) -> None:
            if init_cached_hf_modules:
                from transformers.dynamic_module_utils import init_hf_modules
                init_hf_modules()
            self.worker = None
            # The compiled DAG runs its main execution in a different thread,
            # which calls cuda.set_device. This flag records whether
            # set_device has already been called on that thread.
            self.compiled_dag_cuda_device_set = False

        def init_worker(self, worker_init_fn):
            self.worker = worker_init_fn()

        def __getattr__(self, name):
            return getattr(self.worker, name)

        def execute_method(self, method, *args, **kwargs):
            try:
                executor = getattr(self, method)
                return executor(*args, **kwargs)
            except Exception as e:
                # Exceptions in a Ray worker may cause a deadlock, so log
                # the error and inform the user before re-raising.
                msg = (f"Error executing method {method}. "
                       "This might cause deadlock in distributed execution.")
                logger.exception(msg)
                raise e

        def get_node_ip(self) -> str:
            return get_ip()

        def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
            node_id = ray.get_runtime_context().get_node_id()
            gpu_ids = ray.get_gpu_ids()
            return node_id, gpu_ids

        def set_cuda_visible_devices(self, device_ids) -> None:
            set_cuda_visible_devices(device_ids)

        def execute_model_compiled_dag_remote(self, ignored):
            """Used only when the compiled DAG is enabled."""
            import torch
            if not self.compiled_dag_cuda_device_set:
                torch.cuda.set_device(self.worker.device)
                self.compiled_dag_cuda_device_set = True
            output = self.worker.execute_model()
            output = pickle.dumps(output)
            return output
except ImportError as e:
    logger.warning(f"Failed to import Ray with {e!r}. "
                   "For distributed inference, please install Ray with "
                   "`pip install ray`.")
    ray = None
    RayWorkerAphrodite = None
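
# Usage sketch (kept as a comment so nothing runs at import time): the engine
# typically wraps RayWorkerAphrodite in ray.remote(), then builds the real
# Worker lazily via init_worker() once Ray has set CUDA_VISIBLE_DEVICES.
# `create_worker` below is a hypothetical stand-in for the engine's worker
# factory, not part of this module.
#
#   worker_cls = ray.remote(num_cpus=0, num_gpus=1)(RayWorkerAphrodite)
#   worker = worker_cls.remote()
#   worker.init_worker.remote(lambda: create_worker())
#   node_id, gpu_ids = ray.get(worker.get_node_and_gpu_ids.remote())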


def initialize_ray_cluster(
    parallel_config: ParallelConfig,
    ray_address: Optional[str] = None,
):
    """Initialize the distributed cluster with Ray.

    This connects to the Ray cluster and creates a placement group for the
    workers, which includes the specification of the resources for each
    distributed worker.

    Args:
        parallel_config: The configurations for parallel execution.
        ray_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address.
    """
    if ray is None:
        raise ImportError(
            "Ray is not installed. Please install Ray to use distributed "
            "serving.")

    # Connect to a Ray cluster.
    if is_hip():
        ray.init(address=ray_address,
                 ignore_reinit_error=True,
                 num_gpus=parallel_config.world_size)
    else:
        ray.init(address=ray_address, ignore_reinit_error=True)

    if parallel_config.placement_group:
        # The placement group is already set.
        return

    # Create a placement group for the worker processes.
    current_placement_group = ray.util.get_current_placement_group()
    if current_placement_group:
        # We are in a placement group.
        bundles = current_placement_group.bundle_specs
        # Verify that we can use the placement group.
        gpu_bundles = 0
        for bundle in bundles:
            bundle_gpus = bundle.get("GPU", 0)
            if bundle_gpus > 1:
                raise ValueError(
                    "Placement group bundle cannot have more than 1 GPU.")
            if bundle_gpus:
                gpu_bundles += 1
        if parallel_config.world_size > gpu_bundles:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the placement group.")
    else:
        num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
        if parallel_config.world_size > num_gpus_in_cluster:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the cluster.")
        # Create a new placement group with one GPU per worker.
        placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
        current_placement_group = ray.util.placement_group(
            placement_group_specs)
        # Wait until the placement group is ready. This blocks until all
        # requested resources are available, and times out if they cannot
        # be provisioned.
        ray.get(current_placement_group.ready(), timeout=1800)

    # Set the placement group in the parallel config.
    parallel_config.placement_group = current_placement_group
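
# Usage sketch (kept as a comment so nothing runs at import time): after
# building a ParallelConfig, the engine calls initialize_ray_cluster() once
# and then schedules its workers against the resulting placement group, for
# example with Ray's PlacementGroupSchedulingStrategy. `parallel_config` is
# assumed to be an already-constructed ParallelConfig instance.
#
#   from ray.util.scheduling_strategies import (
#       PlacementGroupSchedulingStrategy)
#
#   initialize_ray_cluster(parallel_config)
#   worker_cls = ray.remote(
#       num_gpus=1,
#       scheduling_strategy=PlacementGroupSchedulingStrategy(
#           placement_group=parallel_config.placement_group,
#           placement_group_capture_child_tasks=True,
#       ),
#   )(RayWorkerAphrodite)
#   worker = worker_cls.remote()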