parallel_state.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. # Copyright 2023 The PygmalionAI team.
  2. # Copyright 2023 The vLLM team.
  3. # Adapted from
  4. # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
  5. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
  6. """Tensor and pipeline parallel groups."""
  7. import contextlib
  8. import os
  9. from typing import Optional
  10. import torch
  11. from loguru import logger
  12. # Tensor model parallel group that the current rank belongs to.
  13. _TP_DEVICE_GROUP = None
  14. _TP_CPU_GROUP = None
  15. # Pipeline model parallel group that the current rank belongs to.
  16. _PIPELINE_MODEL_PARALLEL_GROUP = None
  17. # when people blindly call `torch.distributed.all_reduce` etc,
  18. # it will use this group. It is initialized with the `backend`
  19. # parameter of `init_distributed_environment` below.
  20. # Essentially, this is `torch.distributed.group.WORLD`.
  21. # We leave a line here to note that this is device-specific.
  22. # Note that this variable is not safe to use, because when users
  23. # call `init_distributed_environment` first, and then destroy
  24. # the process group themselves, this variable will keep a reference to the
  25. # destroyed process group, which is not useful.
  26. _DEVICE_WORLD_GROUP = None
  27. # duing `init_distributed_environment`, we will also initialize a
  28. # group with `gloo` backend, to allow direct coordination between
  29. # processes through the CPU.
  30. _CPU_WORLD_GROUP = None
  31. # In summary, after calling `init_distributed_environment`, we will
  32. # always have two groups: one for device-specific (and is the default)
  33. # and one for CPU. All processes will be part of both groups.
  34. # A list of global ranks for each pipeline group to ease calculation of the
  35. # source rank when broadcasting from the first or last pipeline stage.
  36. _PIPELINE_GLOBAL_RANKS = None
  37. _LOCAL_RANK = -1
  38. def get_local_rank():
  39. global _LOCAL_RANK
  40. return _LOCAL_RANK
  41. def init_distributed_environment(
  42. world_size: int = -1,
  43. rank: int = -1,
  44. distributed_init_method: str = "env://",
  45. local_rank: int = -1,
  46. backend: str = "nccl",
  47. ):
  48. logger.debug(f"{world_size=} {rank=} {local_rank=} "
  49. f"{distributed_init_method=} {backend=}")
  50. if not torch.distributed.is_initialized():
  51. assert distributed_init_method is not None, (
  52. "distributed_init_method must be provided when initializing "
  53. "distributed environment")
  54. # this backend is used for WORLD
  55. torch.distributed.init_process_group(
  56. backend=backend,
  57. init_method=distributed_init_method,
  58. world_size=world_size,
  59. rank=rank)
  60. global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP
  61. _DEVICE_WORLD_GROUP = torch.distributed.group.WORLD
  62. ranks = list(range(torch.distributed.get_world_size()))
  63. _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
  64. backend="gloo")
  65. if local_rank == -1 and distributed_init_method == "env://":
  66. # Get local rank from environment variable.
  67. local_rank = int(os.environ.get["LOCAL_RANK"])
  68. global _LOCAL_RANK
  69. _LOCAL_RANK = local_rank
  70. def initialize_model_parallel(
  71. tensor_model_parallel_size: int = 1,
  72. pipeline_model_parallel_size: int = 1,
  73. backend: Optional[str] = None,
  74. ) -> None:
  75. """
  76. Initialize model parallel groups.
  77. Arguments:
  78. tensor_model_parallel_size: number of GPUs used for tensor model
  79. parallelism.
  80. pipeline_model_parallel_size: number of GPUs used for pipeline model
  81. parallelism.
  82. Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
  83. use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
  84. the model pipeline. The present function will
  85. create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
  86. 4 tensor model-parallel groups:
  87. [g0, g1], [g2, g3], [g4, g5], [g6, g7]
  88. 2 pipeline model-parallel groups:
  89. [g0, g2, g4, g6], [g1, g3, g5, g7]
  90. Note that for efficiency, the caller should make sure adjacent ranks
  91. are on the same DGX box. For example if we are using 2 DGX-1 boxes
  92. with a total of 16 GPUs, rank 0 to 7 belong to the first box and
  93. ranks 8 to 15 belong to the second box.
  94. """
  95. # Get world size and rank. Ensure some consistencies.
  96. assert torch.distributed.is_initialized()
  97. world_size: int = torch.distributed.get_world_size()
  98. # get the backend of _DEVICE_WORLD_GROUP
  99. backend = backend or torch.distributed.get_backend()
  100. if (world_size !=
  101. tensor_model_parallel_size * pipeline_model_parallel_size):
  102. raise RuntimeError(
  103. f"world_size ({world_size}) is not equal to "
  104. f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
  105. f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
  106. num_tensor_model_parallel_groups: int = (world_size //
  107. tensor_model_parallel_size)
  108. num_pipeline_model_parallel_groups: int = (world_size //
  109. pipeline_model_parallel_size)
  110. rank = torch.distributed.get_rank()
  111. # Build the tensor model-parallel groups.
  112. global _TP_DEVICE_GROUP, _TP_CPU_GROUP
  113. assert _TP_DEVICE_GROUP is None, (
  114. "tensor model parallel group is already initialized")
  115. for i in range(num_tensor_model_parallel_groups):
  116. ranks = range(i * tensor_model_parallel_size,
  117. (i + 1) * tensor_model_parallel_size)
  118. group = torch.distributed.new_group(ranks, backend=backend)
  119. cpu_group = torch.distributed.new_group(ranks, backend="gloo")
  120. if rank in ranks:
  121. _TP_DEVICE_GROUP = group
  122. _TP_CPU_GROUP = cpu_group
  123. # Build the pipeline model-parallel groups.
  124. global _PIPELINE_MODEL_PARALLEL_GROUP
  125. global _PIPELINE_GLOBAL_RANKS
  126. assert _PIPELINE_MODEL_PARALLEL_GROUP is None, (
  127. "pipeline model parallel group is already initialized")
  128. for i in range(num_pipeline_model_parallel_groups):
  129. ranks = range(i, world_size, num_pipeline_model_parallel_groups)
  130. group = torch.distributed.new_group(ranks, backend=backend)
  131. if rank in ranks:
  132. _PIPELINE_MODEL_PARALLEL_GROUP = group
  133. _PIPELINE_GLOBAL_RANKS = ranks
  134. def ensure_model_parallel_initialized(
  135. tensor_model_parallel_size: int,
  136. pipeline_model_parallel_size: int,
  137. backend: Optional[str] = None,
  138. ) -> None:
  139. """Helper to initialize model parallel groups if they are not initialized,
  140. or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
  141. values if the model parallel groups are initialized.
  142. """
  143. # get the backend of _DEVICE_WORLD_GROUP
  144. backend = backend or torch.distributed.get_backend()
  145. if not model_parallel_is_initialized():
  146. initialize_model_parallel(tensor_model_parallel_size,
  147. pipeline_model_parallel_size, backend)
  148. return
  149. assert (
  150. get_tensor_model_parallel_world_size() == tensor_model_parallel_size
  151. ), ("tensor parallel group already initialized, but of unexpected size: "
  152. f"{get_tensor_model_parallel_world_size()=} vs. "
  153. f"{tensor_model_parallel_size=}")
  154. assert (get_pipeline_model_parallel_world_size(
  155. ) == pipeline_model_parallel_size), (
  156. "pipeline parallel group already initialized, but of unexpected size: "
  157. f"{get_pipeline_model_parallel_world_size()=} vs. "
  158. f"{pipeline_model_parallel_size=}")
  159. def model_parallel_is_initialized():
  160. """Check if tensor and pipeline parallel groups are initialized."""
  161. return (_TP_DEVICE_GROUP is not None
  162. and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
  163. def get_cpu_world_group():
  164. """Get the CPU world group."""
  165. assert _CPU_WORLD_GROUP is not None, ("CPU world group is not initialized")
  166. return _CPU_WORLD_GROUP
  167. def get_tensor_model_parallel_group():
  168. """Get the tensor model parallel group the caller rank belongs to."""
  169. assert _TP_DEVICE_GROUP is not None, (
  170. "tensor model parallel group is not initialized")
  171. return _TP_DEVICE_GROUP
  172. def get_tensor_model_parallel_cpu_group():
  173. """Get the tensor model parallel cpu group the caller rank belongs to."""
  174. assert _TP_CPU_GROUP is not None, (
  175. "tensor model parallel cpu group is not initialized")
  176. return _TP_CPU_GROUP
  177. def get_pipeline_model_parallel_group():
  178. """Get the pipeline model parallel group the caller rank belongs to."""
  179. assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, (
  180. "pipeline model parallel group is not initialized")
  181. return _PIPELINE_MODEL_PARALLEL_GROUP
  182. def get_tensor_model_parallel_world_size():
  183. """Return world size for the tensor model parallel group."""
  184. return torch.distributed.get_world_size(
  185. group=get_tensor_model_parallel_group())
  186. def get_pipeline_model_parallel_world_size():
  187. """Return world size for the pipeline model parallel group."""
  188. return torch.distributed.get_world_size(
  189. group=get_pipeline_model_parallel_group())
  190. def get_tensor_model_parallel_rank():
  191. """Return my rank for the tensor model parallel group."""
  192. return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
  193. def get_pipeline_model_parallel_rank():
  194. """Return my rank for the pipeline model parallel group."""
  195. return torch.distributed.get_rank(
  196. group=get_pipeline_model_parallel_group())
  197. def get_tensor_model_parallel_src_rank():
  198. """Calculate the global rank corresponding to the first local rank
  199. in the tensor model parallel group."""
  200. global_rank = torch.distributed.get_rank()
  201. local_world_size = get_tensor_model_parallel_world_size()
  202. return (global_rank // local_world_size) * local_world_size
  203. def get_pipeline_model_parallel_first_rank():
  204. """Return the global rank of the first process in the pipeline for the
  205. current tensor parallel group"""
  206. assert _PIPELINE_GLOBAL_RANKS is not None, (
  207. "Pipeline parallel group is not initialized")
  208. return _PIPELINE_GLOBAL_RANKS[0]
  209. def get_pipeline_model_parallel_last_rank():
  210. """Return the global rank of the last process in the pipeline for the
  211. current tensor parallel group"""
  212. assert _PIPELINE_GLOBAL_RANKS is not None, (
  213. "Pipeline parallel group is not initialized")
  214. last_rank_local = get_pipeline_model_parallel_world_size() - 1
  215. return _PIPELINE_GLOBAL_RANKS[last_rank_local]
  216. def get_pipeline_model_parallel_next_rank():
  217. """Return the global rank that follows the caller in the pipeline"""
  218. assert _PIPELINE_GLOBAL_RANKS is not None, (
  219. "Pipeline parallel group is not initialized")
  220. rank_in_pipeline = get_pipeline_model_parallel_rank()
  221. world_size = get_pipeline_model_parallel_world_size()
  222. return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
  223. def get_pipeline_model_parallel_prev_rank():
  224. """Return the global rank that precedes the caller in the pipeline"""
  225. assert _PIPELINE_GLOBAL_RANKS is not None, (
  226. "Pipeline parallel group is not initialized")
  227. rank_in_pipeline = get_pipeline_model_parallel_rank()
  228. world_size = get_pipeline_model_parallel_world_size()
  229. return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
  230. def destroy_model_parallel():
  231. """Set the groups to none and destroy them."""
  232. global _TP_DEVICE_GROUP
  233. if _TP_DEVICE_GROUP:
  234. torch.distributed.destroy_process_group(_TP_DEVICE_GROUP)
  235. _TP_DEVICE_GROUP = None
  236. global _TP_CPU_GROUP
  237. if _TP_CPU_GROUP:
  238. torch.distributed.destroy_process_group(_TP_CPU_GROUP)
  239. _TP_CPU_GROUP = None
  240. global _PIPELINE_MODEL_PARALLEL_GROUP
  241. if _PIPELINE_MODEL_PARALLEL_GROUP:
  242. torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
  243. _PIPELINE_MODEL_PARALLEL_GROUP = None
  244. global _PIPELINE_GLOBAL_RANKS
  245. _PIPELINE_GLOBAL_RANKS = None
  246. from aphrodite.distributed.device_communicators import pynccl_utils
  247. # Destroy the pynccl states if any.
  248. pynccl_utils.destroy_process_group()
  249. # Whether to use pynccl for nccl all reduce.
  250. # We use pynccl for all reduce when using CUDA graph, because torch.distributed
  251. # is not well supported by CUDA graph.
  252. _ENABLE_PYNCCL_FOR_ALL_REDUCE = False
  253. @contextlib.contextmanager
  254. def with_pynccl_for_all_reduce():
  255. """use Pynccl instead of torch.distributed for all reduce"""
  256. from aphrodite.distributed.device_communicators import pynccl_utils
  257. tp_size = get_tensor_model_parallel_world_size()
  258. if tp_size == 1:
  259. # No-op.
  260. # NOTE: We don't initialize Pynccl when tp_size is 1.
  261. yield
  262. else:
  263. global _ENABLE_PYNCCL_FOR_ALL_REDUCE
  264. old = _ENABLE_PYNCCL_FOR_ALL_REDUCE
  265. _ENABLE_PYNCCL_FOR_ALL_REDUCE = True
  266. stream = torch.cuda.current_stream()
  267. with pynccl_utils.set_pynccl_stream(stream):
  268. yield
  269. _ENABLE_PYNCCL_FOR_ALL_REDUCE = old
  270. def is_pynccl_enabled_for_all_reduce():
  271. """check if Pynccl is enabled for all reduce"""
  272. global _ENABLE_PYNCCL_FOR_ALL_REDUCE
  273. return _ENABLE_PYNCCL_FOR_ALL_REDUCE