1
0

parallel_state.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. # Copyright 2023 The PygmalionAI team.
  2. # Copyright 2023 The vLLM team.
  3. # Adapted from
  4. # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
  5. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
  6. """Tensor and pipeline parallel groups."""
  7. import contextlib
  8. import os
  9. from typing import Optional
  10. import torch
  11. from loguru import logger
  12. # Tensor model parallel group that the current rank belongs to.
  13. _TENSOR_MODEL_PARALLEL_GROUP = None
  14. # Pipeline model parallel group that the current rank belongs to.
  15. _PIPELINE_MODEL_PARALLEL_GROUP = None
  16. # when people blindly call `torch.distributed.all_reduce` etc,
  17. # it will use this group. It is initialized with the `backend`
  18. # parameter of `init_distributed_environment` below.
  19. # Essentially, this is `torch.distributed.group.WORLD`.
  20. # We leave a line here to note that this is device-specific.
  21. # Note that this variable is not safe to use, because when users
  22. # call `init_distributed_environment` first, and then destroy
  23. # the process group themselves, this variable will keep a reference to the
  24. # destroyed process group, which is not useful.
  25. _DEVICE_WORLD_GROUP = None
  26. # duing `init_distributed_environment`, we will also initialize a
  27. # group with `gloo` backend, to allow direct coordination between
  28. # processes through the CPU.
  29. _CPU_WORLD_GROUP = None
  30. # In summary, after calling `init_distributed_environment`, we will
  31. # always have two groups: one for device-specific (and is the default)
  32. # and one for CPU. All processes will be part of both groups.
  33. # A list of global ranks for each pipeline group to ease calculation of the
  34. # source rank when broadcasting from the first or last pipeline stage.
  35. _PIPELINE_GLOBAL_RANKS = None
  36. _LOCAL_RANK = -1
  37. def get_local_rank():
  38. global _LOCAL_RANK
  39. return _LOCAL_RANK
  40. def init_distributed_environment(
  41. world_size: int = -1,
  42. rank: int = -1,
  43. distributed_init_method: str = "env://",
  44. local_rank: int = -1,
  45. backend: str = "nccl",
  46. ):
  47. logger.debug(f"{world_size=} {rank=} {local_rank=} "
  48. f"{distributed_init_method=} {backend=}")
  49. if not torch.distributed.is_initialized():
  50. assert distributed_init_method is not None, (
  51. "distributed_init_method must be provided when initializing "
  52. "distributed environment")
  53. # this backend is used for WORLD
  54. torch.distributed.init_process_group(
  55. backend=backend,
  56. init_method=distributed_init_method,
  57. world_size=world_size,
  58. rank=rank)
  59. global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP
  60. _DEVICE_WORLD_GROUP = torch.distributed.group.WORLD
  61. ranks = list(range(torch.distributed.get_world_size()))
  62. _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
  63. backend="gloo")
  64. if local_rank == -1 and distributed_init_method == "env://":
  65. # Get local rank from environment variable.
  66. local_rank = int(os.environ.get["LOCAL_RANK"])
  67. global _LOCAL_RANK
  68. _LOCAL_RANK = local_rank
  69. def initialize_model_parallel(
  70. tensor_model_parallel_size: int = 1,
  71. pipeline_model_parallel_size: int = 1,
  72. backend: Optional[str] = None,
  73. ) -> None:
  74. """
  75. Initialize model parallel groups.
  76. Arguments:
  77. tensor_model_parallel_size: number of GPUs used for tensor model
  78. parallelism.
  79. pipeline_model_parallel_size: number of GPUs used for pipeline model
  80. parallelism.
  81. Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
  82. use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
  83. the model pipeline. The present function will
  84. create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
  85. 4 tensor model-parallel groups:
  86. [g0, g1], [g2, g3], [g4, g5], [g6, g7]
  87. 2 pipeline model-parallel groups:
  88. [g0, g2, g4, g6], [g1, g3, g5, g7]
  89. Note that for efficiency, the caller should make sure adjacent ranks
  90. are on the same DGX box. For example if we are using 2 DGX-1 boxes
  91. with a total of 16 GPUs, rank 0 to 7 belong to the first box and
  92. ranks 8 to 15 belong to the second box.
  93. """
  94. # Get world size and rank. Ensure some consistencies.
  95. assert torch.distributed.is_initialized()
  96. world_size: int = torch.distributed.get_world_size()
  97. # get the backend of _DEVICE_WORLD_GROUP
  98. backend = backend or torch.distributed.get_backend()
  99. if (world_size !=
  100. tensor_model_parallel_size * pipeline_model_parallel_size):
  101. raise RuntimeError(
  102. f"world_size ({world_size}) is not equal to "
  103. f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
  104. f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
  105. num_tensor_model_parallel_groups: int = (world_size //
  106. tensor_model_parallel_size)
  107. num_pipeline_model_parallel_groups: int = (world_size //
  108. pipeline_model_parallel_size)
  109. rank = torch.distributed.get_rank()
  110. # Build the tensor model-parallel groups.
  111. global _TENSOR_MODEL_PARALLEL_GROUP
  112. assert _TENSOR_MODEL_PARALLEL_GROUP is None, (
  113. "tensor model parallel group is already initialized")
  114. for i in range(num_tensor_model_parallel_groups):
  115. ranks = range(i * tensor_model_parallel_size,
  116. (i + 1) * tensor_model_parallel_size)
  117. group = torch.distributed.new_group(ranks, backend=backend)
  118. if rank in ranks:
  119. _TENSOR_MODEL_PARALLEL_GROUP = group
  120. # Build the pipeline model-parallel groups.
  121. global _PIPELINE_MODEL_PARALLEL_GROUP
  122. global _PIPELINE_GLOBAL_RANKS
  123. assert _PIPELINE_MODEL_PARALLEL_GROUP is None, (
  124. "pipeline model parallel group is already initialized")
  125. for i in range(num_pipeline_model_parallel_groups):
  126. ranks = range(i, world_size, num_pipeline_model_parallel_groups)
  127. group = torch.distributed.new_group(ranks, backend=backend)
  128. if rank in ranks:
  129. _PIPELINE_MODEL_PARALLEL_GROUP = group
  130. _PIPELINE_GLOBAL_RANKS = ranks
  131. def ensure_model_parallel_initialized(
  132. tensor_model_parallel_size: int,
  133. pipeline_model_parallel_size: int,
  134. backend: Optional[str] = None,
  135. ) -> None:
  136. """Helper to initialize model parallel groups if they are not initialized,
  137. or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
  138. values if the model parallel groups are initialized.
  139. """
  140. # get the backend of _DEVICE_WORLD_GROUP
  141. backend = backend or torch.distributed.get_backend()
  142. if not model_parallel_is_initialized():
  143. initialize_model_parallel(tensor_model_parallel_size,
  144. pipeline_model_parallel_size, backend)
  145. return
  146. assert (
  147. get_tensor_model_parallel_world_size() == tensor_model_parallel_size
  148. ), ("tensor parallel group already initialized, but of unexpected size: "
  149. f"{get_tensor_model_parallel_world_size()=} vs. "
  150. f"{tensor_model_parallel_size=}")
  151. assert (get_pipeline_model_parallel_world_size(
  152. ) == pipeline_model_parallel_size), (
  153. "pipeline parallel group already initialized, but of unexpected size: "
  154. f"{get_pipeline_model_parallel_world_size()=} vs. "
  155. f"{pipeline_model_parallel_size=}")
  156. def model_parallel_is_initialized():
  157. """Check if tensor and pipeline parallel groups are initialized."""
  158. return (_TENSOR_MODEL_PARALLEL_GROUP is not None
  159. and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
  160. def get_cpu_world_group():
  161. """Get the CPU world group."""
  162. assert _CPU_WORLD_GROUP is not None, ("CPU world group is not initialized")
  163. return _CPU_WORLD_GROUP
  164. def get_tensor_model_parallel_group():
  165. """Get the tensor model parallel group the caller rank belongs to."""
  166. assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
  167. "tenosr model parallel group is not initialized")
  168. return _TENSOR_MODEL_PARALLEL_GROUP
  169. def get_pipeline_model_parallel_group():
  170. """Get the pipeline model parallel group the caller rank belongs to."""
  171. assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, (
  172. "pipeline model parallel group is not initialized")
  173. return _PIPELINE_MODEL_PARALLEL_GROUP
  174. def get_tensor_model_parallel_world_size():
  175. """Return world size for the tensor model parallel group."""
  176. return torch.distributed.get_world_size(
  177. group=get_tensor_model_parallel_group())
  178. def get_pipeline_model_parallel_world_size():
  179. """Return world size for the pipeline model parallel group."""
  180. return torch.distributed.get_world_size(
  181. group=get_pipeline_model_parallel_group())
  182. def get_tensor_model_parallel_rank():
  183. """Return my rank for the tensor model parallel group."""
  184. return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
  185. def get_pipeline_model_parallel_rank():
  186. """Return my rank for the pipeline model parallel group."""
  187. return torch.distributed.get_rank(
  188. group=get_pipeline_model_parallel_group())
  189. def get_tensor_model_parallel_src_rank():
  190. """Calculate the global rank corresponding to the first local rank
  191. in the tensor model parallel group."""
  192. global_rank = torch.distributed.get_rank()
  193. local_world_size = get_tensor_model_parallel_world_size()
  194. return (global_rank // local_world_size) * local_world_size
  195. def get_pipeline_model_parallel_first_rank():
  196. """Return the global rank of the first process in the pipeline for the
  197. current tensor parallel group"""
  198. assert _PIPELINE_GLOBAL_RANKS is not None, (
  199. "Pipeline parallel group is not initialized")
  200. return _PIPELINE_GLOBAL_RANKS[0]
  201. def get_pipeline_model_parallel_last_rank():
  202. """Return the global rank of the last process in the pipeline for the
  203. current tensor parallel group"""
  204. assert _PIPELINE_GLOBAL_RANKS is not None, (
  205. "Pipeline parallel group is not initialized")
  206. last_rank_local = get_pipeline_model_parallel_world_size() - 1
  207. return _PIPELINE_GLOBAL_RANKS[last_rank_local]
  208. def get_pipeline_model_parallel_next_rank():
  209. """Return the global rank that follows the caller in the pipeline"""
  210. assert _PIPELINE_GLOBAL_RANKS is not None, (
  211. "Pipeline parallel group is not initialized")
  212. rank_in_pipeline = get_pipeline_model_parallel_rank()
  213. world_size = get_pipeline_model_parallel_world_size()
  214. return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
  215. def get_pipeline_model_parallel_prev_rank():
  216. """Return the global rank that precedes the caller in the pipeline"""
  217. assert _PIPELINE_GLOBAL_RANKS is not None, (
  218. "Pipeline parallel group is not initialized")
  219. rank_in_pipeline = get_pipeline_model_parallel_rank()
  220. world_size = get_pipeline_model_parallel_world_size()
  221. return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
  222. def destroy_model_parallel():
  223. """Set the groups to none and destroy them."""
  224. global _TENSOR_MODEL_PARALLEL_GROUP
  225. if _TENSOR_MODEL_PARALLEL_GROUP:
  226. torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP)
  227. _TENSOR_MODEL_PARALLEL_GROUP = None
  228. global _PIPELINE_MODEL_PARALLEL_GROUP
  229. if _PIPELINE_MODEL_PARALLEL_GROUP:
  230. torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
  231. _PIPELINE_MODEL_PARALLEL_GROUP = None
  232. global _PIPELINE_GLOBAL_RANKS
  233. _PIPELINE_GLOBAL_RANKS = None
  234. from aphrodite.distributed.device_communicators import pynccl_utils
  235. # Destroy the pynccl states if any.
  236. pynccl_utils.destroy_process_group()
  237. # Whether to use pynccl for nccl all reduce.
  238. # We use pynccl for all reduce when using CUDA graph, because torch.distributed
  239. # is not well supported by CUDA graph.
  240. _ENABLE_PYNCCL_FOR_ALL_REDUCE = False
  241. @contextlib.contextmanager
  242. def with_pynccl_for_all_reduce():
  243. """use Pynccl instead of torch.distributed for all reduce"""
  244. from aphrodite.distributed.device_communicators import pynccl_utils
  245. tp_size = get_tensor_model_parallel_world_size()
  246. if tp_size == 1:
  247. # No-op.
  248. # NOTE: We don't initialize Pynccl when tp_size is 1.
  249. yield
  250. else:
  251. global _ENABLE_PYNCCL_FOR_ALL_REDUCE
  252. old = _ENABLE_PYNCCL_FOR_ALL_REDUCE
  253. _ENABLE_PYNCCL_FOR_ALL_REDUCE = True
  254. stream = torch.cuda.current_stream()
  255. with pynccl_utils.set_pynccl_stream(stream):
  256. yield
  257. _ENABLE_PYNCCL_FOR_ALL_REDUCE = old
  258. def is_pynccl_enabled_for_all_reduce():
  259. """check if Pynccl is enabled for all reduce"""
  260. global _ENABLE_PYNCCL_FOR_ALL_REDUCE
  261. return _ENABLE_PYNCCL_FOR_ALL_REDUCE