parallel_state.py

# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
import contextlib

import torch

from aphrodite.modeling.megatron import cupy_utils

# Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Pipeline model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# A list of global ranks for each pipeline group to ease calculation of the
# source rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None


def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example, if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, ranks 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()

    if (world_size !=
            tensor_model_parallel_size * pipeline_model_parallel_size):
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")

    num_tensor_model_parallel_groups: int = (world_size //
                                             tensor_model_parallel_size)
    num_pipeline_model_parallel_groups: int = (world_size //
                                               pipeline_model_parallel_size)
    rank = torch.distributed.get_rank()

    # Build the tensor model-parallel groups.
    global _TENSOR_MODEL_PARALLEL_GROUP
    assert _TENSOR_MODEL_PARALLEL_GROUP is None, (
        "tensor model parallel group is already initialized")
    for i in range(num_tensor_model_parallel_groups):
        ranks = range(i * tensor_model_parallel_size,
                      (i + 1) * tensor_model_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group

    # Build the pipeline model-parallel groups.
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS
    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, (
        "pipeline model parallel group is already initialized")
    for i in range(num_pipeline_model_parallel_groups):
        ranks = range(i, world_size, num_pipeline_model_parallel_groups)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
            _PIPELINE_GLOBAL_RANKS = ranks
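
# Worked example (added for illustration, using the values from the docstring
# above): with world_size == 8, tensor_model_parallel_size == 2, and
# pipeline_model_parallel_size == 4, the two loops build
#     tensor groups:   [0, 1], [2, 3], [4, 5], [6, 7]
#     pipeline groups: [0, 2, 4, 6], [1, 3, 5, 7]
# so global rank 5, for instance, ends up in tensor group [4, 5] and pipeline
# group [1, 3, 5, 7].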


def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
) -> None:
    """Helper to initialize model parallel groups if they are not initialized,
    or to check that the tensor-parallel and pipeline-parallel sizes match the
    expected values if the model parallel groups are already initialized.
    """
    if not model_parallel_is_initialized():
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size)
        return
    assert (
        get_tensor_model_parallel_world_size() == tensor_model_parallel_size
    ), ("tensor parallel group already initialized, but of unexpected size: "
        f"{get_tensor_model_parallel_world_size()=} vs. "
        f"{tensor_model_parallel_size=}")
    assert (
        get_pipeline_model_parallel_world_size() == pipeline_model_parallel_size
    ), ("pipeline parallel group already initialized, but of unexpected size: "
        f"{get_pipeline_model_parallel_world_size()=} vs. "
        f"{pipeline_model_parallel_size=}")
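
# Note (added for illustration): a second call with the same sizes only runs
# the assertions above, so e.g. ensure_model_parallel_initialized(2, 4)
# followed by another ensure_model_parallel_initialized(2, 4) is harmless,
# while a follow-up call with different sizes fails the matching assertion.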


def model_parallel_is_initialized():
    """Check if tensor and pipeline parallel groups are initialized."""
    return (_TENSOR_MODEL_PARALLEL_GROUP is not None
            and _PIPELINE_MODEL_PARALLEL_GROUP is not None)


def get_tensor_model_parallel_group():
    """Get the tensor model parallel group the caller rank belongs to."""
    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
        "tensor model parallel group is not initialized")
    return _TENSOR_MODEL_PARALLEL_GROUP


def get_pipeline_model_parallel_group():
    """Get the pipeline model parallel group the caller rank belongs to."""
    assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, (
        "pipeline model parallel group is not initialized")
    return _PIPELINE_MODEL_PARALLEL_GROUP


def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group."""
    return torch.distributed.get_world_size(
        group=get_tensor_model_parallel_group())


def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline model parallel group."""
    return torch.distributed.get_world_size(
        group=get_pipeline_model_parallel_group())


def get_tensor_model_parallel_rank():
    """Return my rank for the tensor model parallel group."""
    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())


def get_pipeline_model_parallel_rank():
    """Return my rank for the pipeline model parallel group."""
    return torch.distributed.get_rank(
        group=get_pipeline_model_parallel_group())


def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group."""
    global_rank = torch.distributed.get_rank()
    local_world_size = get_tensor_model_parallel_world_size()
    return (global_rank // local_world_size) * local_world_size
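
# Worked example (added for illustration): with a tensor-parallel size of 2,
# global ranks 4 and 5 form one tensor group, and both map to
# (rank // 2) * 2 == 4, i.e. the first rank of that group is the source rank.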


def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first process in the pipeline for the
    current tensor parallel group."""
    assert _PIPELINE_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    return _PIPELINE_GLOBAL_RANKS[0]


def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last process in the pipeline for the
    current tensor parallel group."""
    assert _PIPELINE_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    last_rank_local = get_pipeline_model_parallel_world_size() - 1
    return _PIPELINE_GLOBAL_RANKS[last_rank_local]
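
# Worked example (added for illustration): if _PIPELINE_GLOBAL_RANKS is
# [0, 2, 4, 6], the first pipeline rank is global rank 0 and the last is
# _PIPELINE_GLOBAL_RANKS[4 - 1] == 6.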


def get_pipeline_model_parallel_next_rank():
    """Return the global rank that follows the caller in the pipeline."""
    assert _PIPELINE_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]


def get_pipeline_model_parallel_prev_rank():
    """Return the global rank that precedes the caller in the pipeline."""
    assert _PIPELINE_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
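
# Worked example (added for illustration): continuing with
# _PIPELINE_GLOBAL_RANKS == [0, 2, 4, 6], the caller at global rank 6 has
# pipeline rank 3, so its "next" rank wraps around to 0 and its "previous"
# rank is 4.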


def destroy_model_parallel():
    """Set the groups to None and destroy them."""
    global _TENSOR_MODEL_PARALLEL_GROUP
    if _TENSOR_MODEL_PARALLEL_GROUP:
        torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP)
    _TENSOR_MODEL_PARALLEL_GROUP = None
    global _PIPELINE_MODEL_PARALLEL_GROUP
    if _PIPELINE_MODEL_PARALLEL_GROUP:
        torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    global _PIPELINE_GLOBAL_RANKS
    _PIPELINE_GLOBAL_RANKS = None

    # Destroy the cupy states if any.
    cupy_utils.destroy_process_group()
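
# Note (added for illustration): because initialize_model_parallel only checks
# that the group globals are None, calling destroy_model_parallel first makes
# it possible to re-initialize with different parallel sizes later.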


# Whether to use cupy for nccl all reduce.
# We use cupy for all reduce when using CUDA graph, because torch.distributed
# is not well supported by CUDA graph.
_ENABLE_CUPY_FOR_ALL_REDUCE = False


@contextlib.contextmanager
def with_cupy_nccl_for_all_reduce():
    """Use CuPy NCCL instead of torch.distributed for all-reduce."""
    tp_size = get_tensor_model_parallel_world_size()
    if tp_size == 1:
        # No-op.
        # NOTE: We don't initialize CuPy when tp_size is 1.
        yield
    else:
        global _ENABLE_CUPY_FOR_ALL_REDUCE
        old = _ENABLE_CUPY_FOR_ALL_REDUCE
        _ENABLE_CUPY_FOR_ALL_REDUCE = True

        stream = torch.cuda.current_stream()
        with cupy_utils.set_cupy_stream(stream):
            yield
        _ENABLE_CUPY_FOR_ALL_REDUCE = old


def is_cupy_nccl_enabled_for_all_reduce():
    """Check if CuPy NCCL is enabled for all-reduce."""
    global _ENABLE_CUPY_FOR_ALL_REDUCE
    return _ENABLE_CUPY_FOR_ALL_REDUCE
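

# Usage sketch (added for illustration; not part of the original module).
# A minimal driver, assuming the process is launched with
# `torchrun --nproc_per_node=2 parallel_state.py` so that RANK, WORLD_SIZE,
# LOCAL_RANK, MASTER_ADDR, and MASTER_PORT are set for init_process_group,
# and that at least two GPUs are visible. The parallel sizes below are
# placeholder choices, not values required by this module.
if __name__ == "__main__":
    import os

    # Bind this process to its GPU before creating NCCL groups.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    torch.distributed.init_process_group(backend="nccl")
    ensure_model_parallel_initialized(tensor_model_parallel_size=2,
                                      pipeline_model_parallel_size=1)
    print(f"global rank {torch.distributed.get_rank()}: "
          f"tp rank {get_tensor_model_parallel_rank()}, "
          f"pp rank {get_pipeline_model_parallel_rank()}")
    # Tear down the model-parallel groups and the default process group.
    destroy_model_parallel()
    torch.distributed.destroy_process_group()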