parallel_state.py

# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
import os
from typing import List, Optional

import torch
from loguru import logger
from torch.distributed import ProcessGroup

_ENABLE_CUSTOM_ALL_REDUCE = True

# Tensor model parallel group that the current rank belongs to.
_TP_DEVICE_GROUP: Optional[ProcessGroup] = None
_TP_CPU_GROUP: Optional[ProcessGroup] = None
_TP_PYNCCL_COMMUNICATOR = None
_TP_CA_COMMUNICATOR = None
# Pipeline model parallel group that the current rank belongs to.
_PP_DEVICE_GROUP: Optional[ProcessGroup] = None
_PP_CPU_GROUP: Optional[ProcessGroup] = None
_PP_PYNCCL_COMMUNICATOR = None

# When people blindly call `torch.distributed.all_reduce` etc.,
# it will use this group. It is initialized with the `backend`
# parameter of `init_distributed_environment` below.
# Essentially, this is `torch.distributed.group.WORLD`.
# We leave a line here to note that this is device-specific.
# Note that this variable is not safe to use, because when users
# call `init_distributed_environment` first, and then destroy
# the process group themselves, this variable will keep a reference to the
# destroyed process group, which is no longer usable.
_DEVICE_WORLD_GROUP = None

# During `init_distributed_environment`, we will also initialize a
# group with the `gloo` backend, to allow direct coordination between
# processes through the CPU.
_CPU_WORLD_GROUP = None

# In summary, after calling `init_distributed_environment`, we will
# always have two groups: one that is device-specific (and is the default)
# and one for the CPU. All processes will be part of both groups.

# A list of global ranks for each pipeline group to ease calculation of the
# source rank when broadcasting from the first or last pipeline stage.
_PP_GLOBAL_RANKS: Optional[List[int]] = None

_LOCAL_RANK = -1


def set_custom_all_reduce(enable: bool):
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable


def get_pp_pynccl_communicator():
    global _PP_PYNCCL_COMMUNICATOR
    return _PP_PYNCCL_COMMUNICATOR


def get_tp_pynccl_communicator():
    global _TP_PYNCCL_COMMUNICATOR
    return _TP_PYNCCL_COMMUNICATOR


def get_tp_ca_communicator():
    global _TP_CA_COMMUNICATOR
    return _TP_CA_COMMUNICATOR


def get_local_rank():
    global _LOCAL_RANK
    return _LOCAL_RANK


def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    logger.debug(f"{world_size=} {rank=} {local_rank=} "
                 f"{distributed_init_method=} {backend=}")
    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment")
        # This backend is used for WORLD.
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank)
        global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP
        _DEVICE_WORLD_GROUP = torch.distributed.group.WORLD
        ranks = list(range(torch.distributed.get_world_size()))
        _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
                                                       backend="gloo")
        if local_rank == -1:
            # Local rank not set; this usually happens in a single-node
            # setting, where we can use the rank as the local rank.
            if distributed_init_method == "env://":
                local_rank = int(os.environ.get("LOCAL_RANK", rank))
            else:
                local_rank = rank
        global _LOCAL_RANK
        _LOCAL_RANK = local_rank

    # A small all_reduce for warmup.
    data = torch.zeros(1)
    if torch.cuda.is_available():
        data = data.to(device=f"cuda:{local_rank}")
    torch.distributed.all_reduce(data)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    del data
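
# A minimal usage sketch (not part of this module; assumes the process was
# launched with `torchrun`, which exports RANK, WORLD_SIZE and LOCAL_RANK):
#
#   init_distributed_environment(
#       world_size=int(os.environ["WORLD_SIZE"]),
#       rank=int(os.environ["RANK"]),
#       local_rank=int(os.environ["LOCAL_RANK"]),
#       distributed_init_method="env://",
#       backend="nccl",
#   )
#
# After this call every rank is a member of both the default device group and
# the gloo-backed CPU group returned by `get_cpu_world_group()`.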


def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    # get the backend of _DEVICE_WORLD_GROUP
    backend = backend or torch.distributed.get_backend()

    if (world_size !=
            tensor_model_parallel_size * pipeline_model_parallel_size):
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")

    num_tensor_model_parallel_groups: int = (world_size //
                                             tensor_model_parallel_size)
    num_pipeline_model_parallel_groups: int = (world_size //
                                               pipeline_model_parallel_size)
    rank = torch.distributed.get_rank()

    # Build the tensor model-parallel groups.
    global _TP_DEVICE_GROUP, _TP_CPU_GROUP
    global _TP_PYNCCL_COMMUNICATOR, _TP_CA_COMMUNICATOR
    assert _TP_DEVICE_GROUP is None, (
        "tensor model parallel group is already initialized")
    for i in range(num_tensor_model_parallel_groups):
        ranks = list(
            range(i * tensor_model_parallel_size,
                  (i + 1) * tensor_model_parallel_size))
        group = torch.distributed.new_group(ranks, backend=backend)
        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
        if rank in ranks:
            _TP_DEVICE_GROUP = group
            _TP_CPU_GROUP = cpu_group

    from aphrodite.distributed.device_communicators.pynccl import \
        PyNcclCommunicator
    if tensor_model_parallel_size > 1:
        _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
            group=_TP_CPU_GROUP,
            device=_LOCAL_RANK,
        )

    # Initialize a custom fast all-reduce implementation.
    if _ENABLE_CUSTOM_ALL_REDUCE:
        from aphrodite.distributed.device_communicators.custom_all_reduce \
            import CustomAllreduce
        _TP_CA_COMMUNICATOR = CustomAllreduce(
            group=_TP_CPU_GROUP,
            device=_LOCAL_RANK,
        )

    # Build the pipeline model-parallel groups.
    global _PP_DEVICE_GROUP, _PP_CPU_GROUP
    global _PP_PYNCCL_COMMUNICATOR
    global _PP_GLOBAL_RANKS
    assert _PP_DEVICE_GROUP is None, (
        "pipeline model parallel group is already initialized")
    for i in range(num_pipeline_model_parallel_groups):
        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
        group = torch.distributed.new_group(ranks, backend=backend)
        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
        if rank in ranks:
            _PP_DEVICE_GROUP = group
            _PP_CPU_GROUP = cpu_group
            _PP_GLOBAL_RANKS = ranks

    if pipeline_model_parallel_size > 1:
        _PP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
            group=_PP_CPU_GROUP,
            device=_LOCAL_RANK,
        )
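
# A worked example of the grouping above (a sketch mirroring the docstring,
# not executed here): with world_size=8, tensor_model_parallel_size=2 and
# pipeline_model_parallel_size=4, the two loops produce:
#
#   tp, pp, world = 2, 4, 8
#   tp_groups = [list(range(i * tp, (i + 1) * tp)) for i in range(world // tp)]
#   # -> [[0, 1], [2, 3], [4, 5], [6, 7]]
#   pp_groups = [list(range(i, world, world // pp)) for i in range(world // pp)]
#   # -> [[0, 2, 4, 6], [1, 3, 5, 7]]
#
# Each rank keeps only the (device, cpu) group pair it belongs to.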


def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Helper to initialize model parallel groups if they are not initialized,
    or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
    values if the model parallel groups are initialized.
    """
    # get the backend of _DEVICE_WORLD_GROUP
    backend = backend or torch.distributed.get_backend()
    if not model_parallel_is_initialized():
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size, backend)
        return

    assert (
        get_tensor_model_parallel_world_size() == tensor_model_parallel_size
    ), ("tensor parallel group already initialized, but of unexpected size: "
        f"{get_tensor_model_parallel_world_size()=} vs. "
        f"{tensor_model_parallel_size=}")
    assert (get_pipeline_model_parallel_world_size(
    ) == pipeline_model_parallel_size), (
        "pipeline parallel group already initialized, but of unexpected size: "
        f"{get_pipeline_model_parallel_world_size()=} vs. "
        f"{pipeline_model_parallel_size=}")


def model_parallel_is_initialized():
    """Check if tensor and pipeline parallel groups are initialized."""
    return (_TP_DEVICE_GROUP is not None and _PP_DEVICE_GROUP is not None)


def get_cpu_world_group():
    """Get the CPU world group."""
    assert _CPU_WORLD_GROUP is not None, ("CPU world group is not initialized")
    return _CPU_WORLD_GROUP


def get_tensor_model_parallel_group():
    """Get the tensor model parallel group the caller rank belongs to."""
    assert _TP_DEVICE_GROUP is not None, (
        "tensor model parallel group is not initialized")
    return _TP_DEVICE_GROUP


def get_tensor_model_parallel_cpu_group():
    """Get the tensor model parallel cpu group the caller rank belongs to."""
    assert _TP_CPU_GROUP is not None, (
        "tensor model parallel cpu group is not initialized")
    return _TP_CPU_GROUP


def get_pipeline_model_parallel_group():
    """Get the pipeline model parallel group the caller rank belongs to."""
    assert _PP_DEVICE_GROUP is not None, (
        "pipeline model parallel group is not initialized")
    return _PP_DEVICE_GROUP


def get_pipeline_model_parallel_cpu_group():
    """Get the pipeline model parallel cpu group the caller rank belongs to."""
    assert _PP_CPU_GROUP is not None, (
        "pipeline model parallel cpu group is not initialized")
    return _PP_CPU_GROUP


def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group."""
    return torch.distributed.get_world_size(
        group=get_tensor_model_parallel_group())


def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline model parallel group."""
    return torch.distributed.get_world_size(
        group=get_pipeline_model_parallel_group())


def get_tensor_model_parallel_rank():
    """Return my rank for the tensor model parallel group."""
    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())


def get_pipeline_model_parallel_rank():
    """Return my rank for the pipeline model parallel group."""
    return torch.distributed.get_rank(
        group=get_pipeline_model_parallel_group())


def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group."""
    global_rank = torch.distributed.get_rank()
    local_world_size = get_tensor_model_parallel_world_size()
    return (global_rank // local_world_size) * local_world_size
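
# Worked example for the arithmetic above: with tensor_model_parallel_size=2
# the TP groups are [0, 1], [2, 3], [4, 5], ..., so for global rank 5 the
# source rank is (5 // 2) * 2 == 4, i.e. the first rank of its group [4, 5].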


def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first process in the pipeline for the
    current tensor parallel group."""
    assert _PP_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    return _PP_GLOBAL_RANKS[0]


def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last process in the pipeline for the
    current tensor parallel group."""
    assert _PP_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    last_rank_local = get_pipeline_model_parallel_world_size() - 1
    return _PP_GLOBAL_RANKS[last_rank_local]


def get_pipeline_model_parallel_next_rank():
    """Return the global rank that follows the caller in the pipeline."""
    assert _PP_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PP_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]


def get_pipeline_model_parallel_prev_rank():
    """Return the global rank that precedes the caller in the pipeline."""
    assert _PP_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PP_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
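
# Worked example for the wraparound above: for a 4-stage pipeline whose global
# ranks are [0, 2, 4, 6], the last stage (pipeline rank 3, global rank 6) gets
# next rank _PP_GLOBAL_RANKS[(3 + 1) % 4] == 0, and the first stage gets
# previous rank _PP_GLOBAL_RANKS[(0 - 1) % 4] == 6.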


def destroy_model_parallel():
    """Set the groups to none and destroy them."""
    global _TP_DEVICE_GROUP
    if _TP_DEVICE_GROUP:
        torch.distributed.destroy_process_group(_TP_DEVICE_GROUP)
    _TP_DEVICE_GROUP = None
    global _TP_CPU_GROUP
    if _TP_CPU_GROUP:
        torch.distributed.destroy_process_group(_TP_CPU_GROUP)
    _TP_CPU_GROUP = None
    global _TP_PYNCCL_COMMUNICATOR
    _TP_PYNCCL_COMMUNICATOR = None
    global _PP_DEVICE_GROUP
    if _PP_DEVICE_GROUP:
        torch.distributed.destroy_process_group(_PP_DEVICE_GROUP)
    _PP_DEVICE_GROUP = None
    global _PP_GLOBAL_RANKS
    _PP_GLOBAL_RANKS = None
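
# A teardown sketch (the split of responsibility is an assumption, not
# something this module enforces): destroy the model-parallel groups first,
# then let whoever created the default process group tear it down:
#
#   destroy_model_parallel()
#   torch.distributed.destroy_process_group()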