# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Aphrodite distributed state.

It takes over control of the distributed environment from PyTorch.
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
  initialize the model parallel groups.
- run any code that uses the distributed state.
- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.

If you only need the distributed environment without model/pipeline
parallelism, you can skip the model parallel initialization and destruction
steps. A usage sketch follows the imports below.
"""
import contextlib
import os
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import Any, Dict, List, Optional, Tuple, Union
from unittest.mock import patch

import torch
from loguru import logger
from torch.distributed import Backend, ProcessGroup
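

# Illustrative sketch of the workflow described in the module docstring. This
# helper is not part of the module's public API; its name and the chosen init
# method / backend ("tcp://127.0.0.1:29500", "gloo") are examples only. It
# runs in a single process on CPU, so no GPUs or launcher are required.
def _example_single_process_workflow() -> None:
    # 1. Bring up the global distributed environment (world size 1).
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method="tcp://127.0.0.1:29500",
        local_rank=0,
        backend="gloo",
    )
    # 2. Create tensor-/pipeline-parallel groups (both of size 1 here).
    initialize_model_parallel(tensor_model_parallel_size=1,
                              pipeline_model_parallel_size=1)
    # ... code using get_tp_group() / get_pp_group() would go here ...
    # 3. Tear everything down in reverse order.
    destroy_model_parallel()
    destroy_distributed_environment()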


@dataclass
class GraphCaptureContext:
    stream: torch.cuda.Stream


TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])


def _split_tensor_dict(
    tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
       by its metadata.
    2. A list of tensors.
    """
    metadata_list = []
    tensor_list = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Note: we cannot use `value.device` here, because it contains
            # not only the device type but also the device index
            # (e.g. "cuda:0"). We only need the device type; the receiving
            # side will set the device index.
            device = value.device.type
            metadata_list.append(
                (key, TensorMetadata(device, value.dtype, value.size())))
            tensor_list.append(value)
        else:
            metadata_list.append((key, value))
    return metadata_list, tensor_list
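

# Small self-contained sketch of what `_split_tensor_dict` produces. This
# helper is illustrative only (the name is hypothetical, not module API).
def _example_split_tensor_dict() -> None:
    d = {"hidden": torch.zeros(2, 4, dtype=torch.float16), "step": 7}
    metadata_list, tensor_list = _split_tensor_dict(d)
    # The tensor value is replaced by TensorMetadata("cpu", torch.float16,
    # torch.Size([2, 4])); non-tensor values pass through unchanged.
    assert metadata_list[0][0] == "hidden"
    assert isinstance(metadata_list[0][1], TensorMetadata)
    assert metadata_list[1] == ("step", 7)
    assert len(tensor_list) == 1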


class GroupCoordinator:
    """
    PyTorch ProcessGroup wrapper for a group of processes.
    PyTorch ProcessGroup is bound to one specific communication backend,
    e.g. NCCL, Gloo, MPI, etc.
    GroupCoordinator takes charge of all the communication operations among
    the processes in the group. It can route the communication to
    a specific implementation (e.g. switch allreduce implementation
    based on the tensor size and cuda graph mode).
    """

    # available attributes:
    rank: int  # global rank
    ranks: List[int]  # global ranks in the group
    world_size: int  # size of the group
    # difference between `local_rank` and `rank_in_group`:
    # if we have a group of size 4 across two nodes:
    # Process | Node | Rank | Local Rank | Rank in Group
    #    0    |   0  |   0  |      0     |       0
    #    1    |   0  |   1  |      1     |       1
    #    2    |   1  |   2  |      0     |       2
    #    3    |   1  |   3  |      1     |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    use_pynccl: bool  # a hint of whether to use PyNccl
    use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
    # communicators are only created for world size > 1
    pynccl_comm: Optional[Any]  # PyNccl communicator
    ca_comm: Optional[Any]  # Custom allreduce communicator

    def __init__(
        self,
        group_ranks: List[List[int]],
        local_rank: int,
        torch_distributed_backend: Union[str, Backend],
        use_pynccl: bool,
        use_custom_allreduce: bool,
    ):
        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank
        self.device_group = None
        self.cpu_group = None

        for ranks in group_ranks:
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend)
            # a group with `gloo` backend, to allow direct coordination between
            # processes through the CPU.
            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
                self.device_group = device_group
                self.cpu_group = cpu_group

        assert self.cpu_group is not None
        assert self.device_group is not None

        if torch.cuda.is_available():
            self.device = torch.device(f"cuda:{local_rank}")
        else:
            self.device = torch.device("cpu")

        self.use_pynccl = use_pynccl
        self.use_custom_allreduce = use_custom_allreduce

        # lazy import to avoid documentation build error
        from aphrodite.distributed.device_communicators.custom_all_reduce import \
            CustomAllreduce  # noqa: E501
        from aphrodite.distributed.device_communicators.pynccl import \
            PyNcclCommunicator

        self.pynccl_comm: Optional[PyNcclCommunicator]
        if use_pynccl and self.world_size > 1:
            self.pynccl_comm = PyNcclCommunicator(
                group=self.cpu_group,
                device=self.device,
            )
        else:
            self.pynccl_comm = None

        self.ca_comm: Optional[CustomAllreduce]
        if use_custom_allreduce and self.world_size > 1:
            # Initialize a custom fast all-reduce implementation.
            self.ca_comm = CustomAllreduce(
                group=self.cpu_group,
                device=self.device,
            )
        else:
            self.ca_comm = None

    @property
    def first_rank(self):
        """Return the global rank of the first process in the group"""
        return self.ranks[0]

    @property
    def last_rank(self):
        """Return the global rank of the last process in the group"""
        return self.ranks[-1]

    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group - 1) % world_size]

    @contextmanager
    def graph_capture(
            self, graph_capture_context: Optional[GraphCaptureContext] = None):
        if graph_capture_context is None:
            stream = torch.cuda.Stream()
            graph_capture_context = GraphCaptureContext(stream)
        else:
            stream = graph_capture_context.stream

        ca_comm = self.ca_comm
        maybe_ca_context = (nullcontext()
                            if ca_comm is None else ca_comm.capture())
        with torch.cuda.stream(stream), maybe_ca_context:
            # In graph mode, we have to be very careful about the collective
            # operations. The current status is:
            #     allreduce \ Mode   |  Eager  |  Graph   |
            # --------------------------------------------
            #     custom allreduce   | enabled | enabled  |
            #     PyNccl             | disabled| enabled  |
            #     torch.distributed  | enabled | disabled |
            #
            # Note that custom allreduce has a runtime check: if the tensor
            # size is too large, it will fall back to the next available
            # option.
            # In summary: When using CUDA graph, we use either the custom
            # all-reduce kernel or pynccl. When not using CUDA graph, we use
            # either the custom all-reduce kernel or PyTorch NCCL. We always
            # prioritize the custom all-reduce kernel but fall back to
            # PyTorch NCCL or pynccl if it is disabled or not supported.
            pynccl_comm = self.pynccl_comm
            maybe_pynccl_context: Any
            if not pynccl_comm:
                maybe_pynccl_context = nullcontext()
            else:
                maybe_pynccl_context = pynccl_comm.change_state(
                    enable=True, stream=torch.cuda.current_stream())
            with maybe_pynccl_context:
                yield graph_capture_context

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
        NOTE: This operation may be applied in-place or out-of-place.
        Always assume this function modifies its input, and use the return
        value as the output.
        """
        ca_comm = self.ca_comm

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        if ca_comm is not None:
            out = ca_comm.custom_all_reduce(input_)
            if out is not None:
                return out
        pynccl_comm = self.pynccl_comm
        if (pynccl_comm is not None and not pynccl_comm.disabled):
            pynccl_comm.all_reduce(input_)
        else:
            torch.distributed.all_reduce(input_, group=self.device_group)
        return input_
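
    # Usage note (illustrative, not part of the class API): because
    # `all_reduce` may run in-place or out-of-place, always rebind the
    # result, e.g.
    #
    #     tensor = get_tp_group().all_reduce(tensor)
    #
    # rather than relying on `tensor` being updated in-place.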

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()
        # Allocate output tensor.
        output_tensor = torch.empty((world_size, ) + input_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        torch.distributed.all_gather_into_tensor(output_tensor,
                                                 input_,
                                                 group=self.device_group)
        # Reshape
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(input_size[:dim] +
                                              (world_size *
                                               input_size[dim], ) +
                                              input_size[dim + 1:])
        return output_tensor
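
    # Shape sketch (illustrative): with world_size == 4 and a per-rank input
    # of shape (8, 16), `all_gather(x, dim=-1)` first forms a (4, 8, 16)
    # buffer, moves the gather axis to `dim`, and reshapes to (8, 64), i.e.
    # the gathered chunks are concatenated along `dim`.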

    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> torch.Tensor:
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the destination's rank inside the group
        (`rank_in_group`), not its global rank.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        # Allocate output tensor.
        if self.rank_in_group == dst:
            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
        else:
            gather_list = None
        # Gather.
        torch.distributed.gather(input_,
                                 gather_list,
                                 dst=self.ranks[dst],
                                 group=self.device_group)
        if self.rank_in_group == dst:
            output_tensor = torch.cat(gather_list, dim=dim)
        else:
            output_tensor = None
        return output_tensor
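
    # Illustrative call pattern: `out = group.gather(x, dst=0, dim=0)` returns
    # the concatenated tensor only on the rank whose rank_in_group == 0; every
    # other rank receives None.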

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor.
        NOTE: `src` is the source's rank inside the group (`rank_in_group`),
        not its global rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
        torch.distributed.broadcast(input_,
                                    src=self.ranks[src],
                                    group=self.device_group)
        return input_

    def broadcast_object_list(self,
                              obj_list: List[Any],
                              src: int = 0,
                              group: Optional[ProcessGroup] = None):
        """Broadcast the input object list.
        NOTE: `src` is the source's rank inside the group (`rank_in_group`),
        not its global rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
        torch.distributed.broadcast_object_list(obj_list,
                                                src=self.ranks[src],
                                                group=self.device_group)
        return obj_list

    def broadcast_tensor_dict(
        self,
        tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
        src: int = 0,
        group: Optional[ProcessGroup] = None,
        metadata_group: Optional[ProcessGroup] = None
    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the source's rank inside the group (`rank_in_group`),
        not its global rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if (not torch.distributed.is_initialized() or self.world_size == 1):
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"
        src = self.ranks[src]

        rank = self.rank
        if rank == src:
            metadata_list: List[Tuple[Any, Any]] = []
            assert isinstance(
                tensor_dict,
                dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
            torch.distributed.broadcast_object_list([metadata_list],
                                                    src=src,
                                                    group=metadata_group)
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    handle = torch.distributed.broadcast(tensor,
                                                         src=src,
                                                         group=metadata_group,
                                                         async_op=True)
                else:
                    # use group for GPU tensors
                    handle = torch.distributed.broadcast(tensor,
                                                         src=src,
                                                         group=group,
                                                         async_op=True)
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()
        else:
            recv_metadata_list = [None]
            torch.distributed.broadcast_object_list(recv_metadata_list,
                                                    src=src,
                                                    group=metadata_group)
            assert recv_metadata_list[0] is not None
            tensor_dict = {}
            async_handles = []
            for key, value in recv_metadata_list[0]:
                if isinstance(value, TensorMetadata):
                    tensor = torch.empty(value.size,
                                         dtype=value.dtype,
                                         device=value.device)
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
                        tensor_dict[key] = tensor
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=src,
                            group=metadata_group,
                            async_op=True)
                    else:
                        # use group for GPU tensors
                        handle = torch.distributed.broadcast(tensor,
                                                             src=src,
                                                             group=group,
                                                             async_op=True)
                    async_handles.append(handle)
                    tensor_dict[key] = tensor
                else:
                    tensor_dict[key] = value
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict
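
    # Protocol sketch (illustrative): the sender splits the dict with
    # `_split_tensor_dict`, broadcasts the metadata list over the CPU (gloo)
    # group, then broadcasts each non-empty tensor asynchronously over the
    # device group (or the CPU group for CPU tensors). Receivers allocate
    # empty tensors from the metadata and post matching broadcasts, e.g.
    #
    #     outputs = get_tp_group().broadcast_tensor_dict(outputs, src=0)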

    def barrier(self):
        """Barrier synchronization among the group.
        NOTE: don't use `device_group` here! `barrier` in NCCL is
        terrible because it is internally a broadcast operation with
        secretly created GPU tensors. It is easy to mess up the current
        device. Use the CPU group instead.
        """
        torch.distributed.barrier(group=self.cpu_group)

    def destroy(self):
        if self.device_group is not None:
            torch.distributed.destroy_process_group(self.device_group)
            self.device_group = None
        if self.cpu_group is not None:
            torch.distributed.destroy_process_group(self.cpu_group)
            self.cpu_group = None
        if self.pynccl_comm is not None:
            self.pynccl_comm = None
        if self.ca_comm is not None:
            self.ca_comm = None


_WORLD: Optional[GroupCoordinator] = None


def get_world_group() -> GroupCoordinator:
    assert _WORLD is not None, ("world group is not initialized")
    return _WORLD


_TP: Optional[GroupCoordinator] = None


def get_tp_group() -> GroupCoordinator:
    assert _TP is not None, ("tensor model parallel group is not initialized")
    return _TP


# kept for backward compatibility
get_tensor_model_parallel_group = get_tp_group

_PP: Optional[GroupCoordinator] = None


def get_pp_group() -> GroupCoordinator:
    assert _PP is not None, (
        "pipeline model parallel group is not initialized")
    return _PP


# kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group


@contextmanager
def graph_capture():
    """
    `graph_capture` is a context manager which should surround the code that
    is capturing the CUDA graph. Its main purpose is to ensure that some
    operations will be run after the graph is captured, before the graph
    is replayed. It returns a `GraphCaptureContext` object which contains the
    necessary data for the graph capture. Currently, it only contains the
    stream that the graph capture is running on. This stream is set to the
    current CUDA stream when the context manager is entered and reset to the
    default stream when the context manager is exited. This is to ensure that
    the graph capture is running on a separate stream from the default stream,
    in order to explicitly distinguish the kernels to capture from other
    kernels possibly launched in the background on the default stream.
    """
    with get_tp_group().graph_capture() as context, \
            get_pp_group().graph_capture(context):
        yield context
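

# Illustrative usage sketch for `graph_capture` (not part of the module API).
# It assumes CUDA is available and that `initialize_model_parallel` has
# already been called; `run_decode_step` is a hypothetical stand-in for the
# work being captured.
def _example_graph_capture_usage(run_decode_step) -> None:
    if not torch.cuda.is_available():
        return
    graph = torch.cuda.CUDAGraph()
    with graph_capture() as capture_context:
        # Kernels launched here run on `capture_context.stream`, so
        # collectives are routed to graph-safe implementations
        # (custom allreduce / pynccl).
        with torch.cuda.graph(graph, stream=capture_context.stream):
            run_decode_step()
    # After capture, replaying the graph re-executes the recorded kernels.
    graph.replay()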


_ENABLE_CUSTOM_ALL_REDUCE = True


def set_custom_all_reduce(enable: bool):
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable


def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    logger.debug(
        "world_size={} rank={} local_rank={} "
        "distributed_init_method={} backend={}", world_size, rank, local_rank,
        distributed_init_method, backend)
    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment")
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank)

    # set the local rank
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # local rank not set, this usually happens in single-node
        # setting, where we can use rank as local rank
        if distributed_init_method == "env://":
            # environment variables are strings; convert to int
            local_rank = int(os.getenv("LOCAL_RANK", rank))
        else:
            local_rank = rank

    global _WORLD
    if _WORLD is None:
        ranks = list(range(torch.distributed.get_world_size()))
        _WORLD = GroupCoordinator(
            group_ranks=[ranks],
            local_rank=local_rank,
            torch_distributed_backend=backend,
            use_pynccl=False,
            use_custom_allreduce=False,
        )
    else:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size")


def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    if (world_size !=
            tensor_model_parallel_size * pipeline_model_parallel_size):
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")

    # Build the tensor model-parallel groups.
    num_tensor_model_parallel_groups: int = (world_size //
                                             tensor_model_parallel_size)
    global _TP
    assert _TP is None, ("tensor model parallel group is already initialized")
    group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        ranks = list(
            range(i * tensor_model_parallel_size,
                  (i + 1) * tensor_model_parallel_size))
        group_ranks.append(ranks)
    _TP = GroupCoordinator(
        group_ranks=group_ranks,
        local_rank=get_world_group().local_rank,
        torch_distributed_backend=backend,
        use_pynccl=True,
        use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
    )

    # Build the pipeline model-parallel groups.
    num_pipeline_model_parallel_groups: int = (world_size //
                                               pipeline_model_parallel_size)
    global _PP
    assert _PP is None, (
        "pipeline model parallel group is already initialized")
    group_ranks = []
    for i in range(num_pipeline_model_parallel_groups):
        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
        group_ranks.append(ranks)
    _PP = GroupCoordinator(
        group_ranks=group_ranks,
        local_rank=get_world_group().local_rank,
        torch_distributed_backend=backend,
        use_pynccl=True,
        use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
    )
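

# Worked example of the group layout described in the docstring above. This
# helper is illustrative only (not module API); it reproduces the rank
# arithmetic of `initialize_model_parallel` for world_size=8, TP=2, PP=4
# without touching torch.distributed.
def _example_parallel_group_layout() -> None:
    world_size, tp_size, pp_size = 8, 2, 4
    tp_groups = [
        list(range(i * tp_size, (i + 1) * tp_size))
        for i in range(world_size // tp_size)
    ]
    pp_groups = [
        list(range(i, world_size, world_size // pp_size))
        for i in range(world_size // pp_size)
    ]
    assert tp_groups == [[0, 1], [2, 3], [4, 5], [6, 7]]
    assert pp_groups == [[0, 2, 4, 6], [1, 3, 5, 7]]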


def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Helper to initialize model parallel groups if they are not initialized,
    or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
    values if the model parallel groups are initialized.
    """
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)
    if not model_parallel_is_initialized():
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size, backend)
        return

    assert (
        get_tensor_model_parallel_world_size() == tensor_model_parallel_size
    ), ("tensor parallel group already initialized, but of unexpected size: "
        f"{get_tensor_model_parallel_world_size()=} vs. "
        f"{tensor_model_parallel_size=}")
    pp_world_size = get_pp_group().world_size
    assert (pp_world_size == pipeline_model_parallel_size), (
        "pipeline parallel group already initialized, but of unexpected size: "
        f"{pp_world_size=} vs. "
        f"{pipeline_model_parallel_size=}")


def model_parallel_is_initialized():
    """Check if tensor and pipeline parallel groups are initialized."""
    return (_TP is not None and _PP is not None)


def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group."""
    return get_tp_group().world_size


def get_tensor_model_parallel_rank():
    """Return my rank for the tensor model parallel group."""
    return get_tp_group().rank_in_group


def destroy_model_parallel():
    """Set the groups to none and destroy them."""
    global _TP
    if _TP:
        _TP.destroy()
    _TP = None

    global _PP
    if _PP:
        _PP.destroy()
    _PP = None


def destroy_distributed_environment():
    global _WORLD
    if _WORLD:
        _WORLD.destroy()
    _WORLD = None
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()


def is_in_the_same_node(pg: ProcessGroup):
    """
    This is a collective operation that checks if all processes in the group
    are on the same node. It tests whether all processes are attached to the
    same memory system (shared access to shared memory).
    """
    assert torch.distributed.get_backend(
        pg) != torch.distributed.Backend.NCCL, (
            "is_in_the_same_node should be tested with a non-NCCL group.")
    # local rank inside the group
    rank = torch.distributed.get_rank(group=pg)
    world_size = torch.distributed.get_world_size(group=pg)

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

    # global ranks of the processes in the group
    ranks = torch.distributed.get_process_group_ranks(pg)

    magic_message = b"magic_message"
    shm = None

    try:
        with contextlib.suppress(OSError):
            if rank == 0:
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[:len(magic_message)] = magic_message
                torch.distributed.broadcast_object_list([shm.name],
                                                        src=ranks[0],
                                                        group=pg)
                is_in_the_same_node[0] = 1
            else:
                # try to open the shared memory segment
                recv = [None]
                torch.distributed.broadcast_object_list(recv,
                                                        src=ranks[0],
                                                        group=pg)
                name = recv[0]
                # fix to https://stackoverflow.com/q/62748654/9191338
                # Python incorrectly tracks shared memory even if it is not
                # created by the process. The following patch is a workaround.
                with patch("multiprocessing.resource_tracker.register",
                           lambda *args, **kwargs: None):
                    shm = shared_memory.SharedMemory(name=name)
                if shm.buf[:len(magic_message)] == magic_message:
                    is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: {}", e)
    finally:
        if shm:
            shm.close()

    torch.distributed.barrier(group=pg)

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == 0 and shm:
            shm.unlink()
    torch.distributed.all_reduce(is_in_the_same_node, group=pg)

    return is_in_the_same_node.sum().item() == world_size
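

# Illustrative usage sketch (not module API): the check must run over a
# non-NCCL group, so the world group's gloo-backed `cpu_group` is a natural
# choice once `init_distributed_environment` has been called.
def _example_same_node_check() -> bool:
    return is_in_the_same_node(get_world_group().cpu_group)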