parallel_state.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215
  1. # Copyright 2023 The PygmalionAI team.
  2. # Copyright 2023 The vLLM team.
  3. # Adapted from
  4. # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
  5. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
  6. """Aphrodite distributed state.
  7. It takes over the control of the distributed environment from PyTorch.
  8. The typical workflow is:
  9. - call `init_distributed_environment` to initialize the distributed environment.
  10. - call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
  11. initialize the model parallel groups.
  12. - any code dealing with the distributed stuff
  13. - call `destroy_model_parallel` to destroy the model parallel groups.
  14. - call `destroy_distributed_environment` to destroy the distributed environment.
  15. If you only need to use the distributed environment without model/pipeline
  16. parallelism, you can skip the model parallel initialization and destruction
  17. steps.
  18. """
  19. import contextlib
  20. import pickle
  21. import sys
  22. import weakref
  23. from collections import namedtuple
  24. from contextlib import contextmanager, nullcontext
  25. from dataclasses import dataclass
  26. from multiprocessing import shared_memory
  27. from typing import Any, Callable, Dict, List, Optional, Tuple, Union
  28. from unittest.mock import patch
  29. import torch
  30. import torch.distributed
  31. from loguru import logger
  32. from torch.distributed import Backend, ProcessGroup
  33. import aphrodite.common.envs as envs
  34. @dataclass
  35. class GraphCaptureContext:
  36. stream: torch.cuda.Stream
  37. TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
  38. def _split_tensor_dict(
  39. tensor_dict: Dict[str, Union[torch.Tensor, Any]]
  40. ) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
  41. """Split the tensor dictionary into two parts:
  42. 1. A list of (key, value) pairs. If the value is a tensor, it is replaced
  43. by its metadata.
  44. 2. A list of tensors.
  45. """
  46. metadata_list: List[Tuple[str, Any]] = []
  47. tensor_list: List[torch.Tensor] = []
  48. for key, value in tensor_dict.items():
  49. if isinstance(value, torch.Tensor):
  50. # Note: we cannot use `value.device` here,
  51. # because it contains not only the device type but also the device
  52. # index (e.g. "cuda:0"). We only need the device type.
  53. # receiving side will set the device index.
  54. device = value.device.type
  55. metadata_list.append(
  56. (key, TensorMetadata(device, value.dtype, value.size())))
  57. tensor_list.append(value)
  58. else:
  59. metadata_list.append((key, value))
  60. return metadata_list, tensor_list
  61. _group_name_counter: Dict[str, int] = {}
  62. def _get_unique_name(name: str) -> str:
  63. """Get a unique name for the group.
  64. Example:
  65. _get_unique_name("tp") -> "tp:0"
  66. _get_unique_name("tp") -> "tp:1"
  67. """
  68. if name not in _group_name_counter:
  69. _group_name_counter[name] = 0
  70. newname = f"{name}:{_group_name_counter[name]}"
  71. _group_name_counter[name] += 1
  72. return newname
  73. _groups: Dict[str, Callable[[], "GroupCoordinator"]] = {}
  74. def _register_group(group: "GroupCoordinator") -> None:
  75. # looks like Python 3.8 does not understand `ReferenceType`
  76. _groups[group.unique_name] = weakref.ref(group) # type: ignore
  77. @torch.library.custom_op("aphrodite::inplace_all_reduce",
  78. mutates_args=["tensor"])
  79. def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
  80. assert group_name in _groups, f"Group {group_name} is not found."
  81. group = _groups[group_name]()
  82. if group is None:
  83. raise ValueError(f"Group {group_name} is destroyed.")
  84. group._all_reduce(tensor)
  85. @inplace_all_reduce.register_fake
  86. def _(tensor: torch.Tensor, group_name: str) -> None:
  87. return
  88. @torch.library.custom_op("aphrodite::outplace_all_reduce", mutates_args=[])
  89. def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
  90. assert group_name in _groups, f"Group {group_name} is not found."
  91. group = _groups[group_name]()
  92. if group is None:
  93. raise ValueError(f"Group {group_name} is destroyed.")
  94. return group._all_reduce(tensor)
  95. @outplace_all_reduce.register_fake
  96. def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
  97. return torch.empty_like(tensor)
  98. class GroupCoordinator:
  99. """
  100. PyTorch ProcessGroup wrapper for a group of processes.
  101. PyTorch ProcessGroup is bound to one specific communication backend,
  102. e.g. NCCL, Gloo, MPI, etc.
  103. GroupCoordinator takes charge of all the communication operations among
  104. the processes in the group. It can route the communication to
  105. a specific implementation (e.g. switch allreduce implementation
  106. based on the tensor size and cuda graph mode).
  107. """
  108. # available attributes:
  109. rank: int # global rank
  110. ranks: List[int] # global ranks in the group
  111. world_size: int # size of the group
  112. # difference between `local_rank` and `rank_in_group`:
  113. # if we have a group of size 4 across two nodes:
  114. # Process | Node | Rank | Local Rank | Rank in Group
  115. # 0 | 0 | 0 | 0 | 0
  116. # 1 | 0 | 1 | 1 | 1
  117. # 2 | 1 | 2 | 0 | 2
  118. # 3 | 1 | 3 | 1 | 3
  119. local_rank: int # local rank used to assign devices
  120. rank_in_group: int # rank inside the group
  121. cpu_group: ProcessGroup # group for CPU communication
  122. device_group: ProcessGroup # group for device communication
  123. use_pynccl: bool # a hint of whether to use PyNccl
  124. use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
  125. # communicators are only created for world size > 1
  126. pynccl_comm: Optional[Any] # PyNccl communicator
  127. ca_comm: Optional[Any] # Custom allreduce communicator
  128. mq_broadcaster: Optional[Any] # shared memory broadcaster
  129. def __init__(
  130. self,
  131. group_ranks: List[List[int]],
  132. local_rank: int,
  133. torch_distributed_backend: Union[str, Backend],
  134. use_pynccl: bool,
  135. use_custom_allreduce: bool,
  136. use_tpu_communicator: bool,
  137. use_message_queue_broadcaster: bool = False,
  138. group_name: Optional[str] = None,
  139. ):
  140. group_name = group_name or "anonymous"
  141. self.unique_name = _get_unique_name(group_name)
  142. _register_group(self)
  143. self.rank = torch.distributed.get_rank()
  144. self.local_rank = local_rank
  145. self.device_group = None
  146. self.cpu_group = None
  147. for ranks in group_ranks:
  148. device_group = torch.distributed.new_group(
  149. ranks, backend=torch_distributed_backend)
  150. # a group with `gloo` backend, to allow direct coordination between
  151. # processes through the CPU.
  152. cpu_group = torch.distributed.new_group(ranks, backend="gloo")
  153. if self.rank in ranks:
  154. self.ranks = ranks
  155. self.world_size = len(ranks)
  156. self.rank_in_group = ranks.index(self.rank)
  157. self.device_group = device_group
  158. self.cpu_group = cpu_group
  159. assert self.cpu_group is not None
  160. assert self.device_group is not None
  161. if torch.cuda.is_available():
  162. self.device = torch.device(f"cuda:{local_rank}")
  163. else:
  164. self.device = torch.device("cpu")
  165. self.use_pynccl = use_pynccl
  166. self.use_custom_allreduce = use_custom_allreduce
  167. self.use_tpu_communicator = use_tpu_communicator
  168. # lazy import to avoid documentation build error
  169. from aphrodite.distributed.device_communicators.custom_all_reduce import ( # noqa: E501
  170. CustomAllreduce)
  171. from aphrodite.distributed.device_communicators.pynccl import (
  172. PyNcclCommunicator)
  173. self.pynccl_comm: Optional[PyNcclCommunicator] = None
  174. if use_pynccl and self.world_size > 1:
  175. self.pynccl_comm = PyNcclCommunicator(
  176. group=self.cpu_group,
  177. device=self.device,
  178. )
  179. self.ca_comm: Optional[CustomAllreduce] = None
  180. if use_custom_allreduce and self.world_size > 1:
  181. # Initialize a custom fast all-reduce implementation.
  182. self.ca_comm = CustomAllreduce(
  183. group=self.cpu_group,
  184. device=self.device,
  185. )
  186. from aphrodite.distributed.device_communicators.tpu_communicator import ( # noqa: E501
  187. TpuCommunicator)
  188. self.tpu_communicator: Optional[TpuCommunicator] = None
  189. if use_tpu_communicator and self.world_size > 1:
  190. self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
  191. from aphrodite.distributed.device_communicators.shm_broadcast import (
  192. MessageQueue)
  193. self.mq_broadcaster: Optional[MessageQueue] = None
  194. if use_message_queue_broadcaster and self.world_size > 1:
  195. self.mq_broadcaster = MessageQueue.create_from_process_group(
  196. self.cpu_group, 1 << 22, 6)
  197. @property
  198. def first_rank(self):
  199. """Return the global rank of the first process in the group"""
  200. return self.ranks[0]
  201. @property
  202. def last_rank(self):
  203. """Return the global rank of the last process in the group"""
  204. return self.ranks[-1]
  205. @property
  206. def is_first_rank(self):
  207. """Return whether the caller is the first process in the group"""
  208. return self.rank == self.first_rank
  209. @property
  210. def is_last_rank(self):
  211. """Return whether the caller is the last process in the group"""
  212. return self.rank == self.last_rank
  213. @property
  214. def next_rank(self):
  215. """Return the global rank of the process that follows the caller"""
  216. rank_in_group = self.rank_in_group
  217. world_size = self.world_size
  218. return self.ranks[(rank_in_group + 1) % world_size]
  219. @property
  220. def prev_rank(self):
  221. """Return the global rank of the process that precedes the caller"""
  222. rank_in_group = self.rank_in_group
  223. world_size = self.world_size
  224. return self.ranks[(rank_in_group - 1) % world_size]
  225. @contextmanager
  226. def graph_capture(
  227. self, graph_capture_context: Optional[GraphCaptureContext] = None):
  228. if graph_capture_context is None:
  229. stream = torch.cuda.Stream()
  230. graph_capture_context = GraphCaptureContext(stream)
  231. else:
  232. stream = graph_capture_context.stream
  233. ca_comm = self.ca_comm
  234. maybe_ca_context = nullcontext(
  235. ) if ca_comm is None else ca_comm.capture()
  236. # ensure all initialization operations complete before attempting to
  237. # capture the graph on another stream
  238. curr_stream = torch.cuda.current_stream()
  239. if curr_stream != stream:
  240. stream.wait_stream(curr_stream)
  241. with torch.cuda.stream(stream), maybe_ca_context:
  242. # In graph mode, we have to be very careful about the collective
  243. # operations. The current status is:
  244. # allreduce \ Mode | Eager | Graph |
  245. # --------------------------------------------
  246. # custom allreduce | enabled | enabled |
  247. # PyNccl | disabled| enabled |
  248. # torch.distributed | enabled | disabled|
  249. #
  250. # Note that custom allreduce will have a runtime check, if the
  251. # tensor size is too large, it will fallback to the next
  252. # available option.
  253. # In summary: When using CUDA graph, we use
  254. # either custom all-reduce kernel or pynccl. When not using
  255. # CUDA graph, we use either custom all-reduce kernel or
  256. # PyTorch NCCL. We always prioritize using custom all-reduce
  257. # kernel but fall back to PyTorch or pynccl if it is
  258. # disabled or not supported.
  259. pynccl_comm = self.pynccl_comm
  260. maybe_pynccl_context: Any
  261. if not pynccl_comm:
  262. maybe_pynccl_context = nullcontext()
  263. else:
  264. maybe_pynccl_context = pynccl_comm.change_state(
  265. enable=True, stream=torch.cuda.current_stream())
  266. with maybe_pynccl_context:
  267. yield graph_capture_context
  268. def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
  269. """
  270. User-facing all-reduce function before we actually call the
  271. all-reduce operation.
  272. We need this because Dynamo does not support passing an arbitrary
  273. object (`self` in this case) to a custom op. We need to pass the
  274. group name as a string, and then look up the group coordinator from
  275. the group name, dispatch the all-reduce operation to the group
  276. coordinator.
  277. In addition, PyTorch custom ops do not support mutation or returning
  278. a new tensor in the same op. So we need to figure out if the op is
  279. in-place or out-of-place ahead of time.
  280. """
  281. # Bypass the function if we are using only 1 GPU.
  282. if self.world_size == 1:
  283. return input_
  284. if self.tpu_communicator is not None and \
  285. not self.tpu_communicator.disabled:
  286. # TPU handles Dynamo with its own logic.
  287. return self._all_reduce(input_)
  288. if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
  289. return torch.ops.aphrodite.outplace_all_reduce(
  290. input_, group_name=self.unique_name)
  291. else:
  292. torch.ops.aphrodite.inplace_all_reduce(input_,
  293. group_name=self.unique_name)
  294. return input_
  295. def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
  296. """
  297. The actual all-reduce implementation.
  298. NOTE: This operation will be applied in-place or out-of-place.
  299. Always assume this function modifies its input, but use the return
  300. value as the output.
  301. """
  302. ca_comm = self.ca_comm
  303. # For TPUs, use TPU communicator.
  304. tpu_comm = self.tpu_communicator
  305. if tpu_comm is not None and not tpu_comm.disabled:
  306. return tpu_comm.all_reduce(input_)
  307. if ca_comm is not None:
  308. out = ca_comm.custom_all_reduce(input_)
  309. if out is not None:
  310. return out
  311. pynccl_comm = self.pynccl_comm
  312. if (pynccl_comm is not None and not pynccl_comm.disabled):
  313. pynccl_comm.all_reduce(input_)
  314. elif input_.is_cpu:
  315. import intel_extension_for_pytorch as ipex
  316. ipex.distributed.all_reduce(input_, group=self.device_group)
  317. else:
  318. torch.distributed.all_reduce(input_, group=self.device_group)
  319. return input_
  320. def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
  321. world_size = self.world_size
  322. # Bypass the function if we are using only 1 GPU.
  323. if world_size == 1:
  324. return input_
  325. assert -input_.dim() <= dim < input_.dim(), (
  326. f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
  327. # For TPUs, use TPU communicator.
  328. tpu_comm = self.tpu_communicator
  329. if tpu_comm is not None and not tpu_comm.disabled:
  330. return tpu_comm.all_gather(input_, dim)
  331. if dim < 0:
  332. # Convert negative dim to positive.
  333. dim += input_.dim()
  334. input_size = input_.size()
  335. # Allocate output tensor.
  336. output_tensor = torch.empty((world_size, ) + input_size,
  337. dtype=input_.dtype,
  338. device=input_.device)
  339. # All-gather.
  340. torch.distributed.all_gather_into_tensor(output_tensor,
  341. input_,
  342. group=self.device_group)
  343. # Reshape
  344. output_tensor = output_tensor.movedim(0, dim)
  345. output_tensor = output_tensor.reshape(input_size[:dim] +
  346. (world_size *
  347. input_size[dim], ) +
  348. input_size[dim + 1:])
  349. return output_tensor
  350. def gather(self,
  351. input_: torch.Tensor,
  352. dst: int = 0,
  353. dim: int = -1) -> Optional[torch.Tensor]:
  354. """
  355. NOTE: We assume that the input tensor is on the same device across
  356. all the ranks.
  357. NOTE: `dst` is the local rank of the destination rank.
  358. """
  359. world_size = self.world_size
  360. # Bypass the function if we are using only 1 GPU.
  361. if world_size == 1:
  362. return input_
  363. assert -input_.dim() <= dim < input_.dim(), (
  364. f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
  365. if dim < 0:
  366. # Convert negative dim to positive.
  367. dim += input_.dim()
  368. # Allocate output tensor.
  369. if self.rank_in_group == dst:
  370. gather_list = [torch.empty_like(input_) for _ in range(world_size)]
  371. else:
  372. gather_list = None
  373. # Gather.
  374. torch.distributed.gather(input_,
  375. gather_list,
  376. dst=self.ranks[dst],
  377. group=self.device_group)
  378. if self.rank_in_group == dst:
  379. output_tensor = torch.cat(gather_list, dim=dim)
  380. else:
  381. output_tensor = None
  382. return output_tensor
  383. def broadcast(self, input_: torch.Tensor, src: int = 0):
  384. """Broadcast the input tensor.
  385. NOTE: `src` is the local rank of the source rank.
  386. """
  387. assert src < self.world_size, f"Invalid src rank ({src})"
  388. # Bypass the function if we are using only 1 GPU.
  389. if self.world_size == 1:
  390. return input_
  391. # Broadcast.
  392. torch.distributed.broadcast(input_,
  393. src=self.ranks[src],
  394. group=self.device_group)
  395. return input_
  396. def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
  397. """Broadcast the input object.
  398. NOTE: `src` is the local rank of the source rank.
  399. """
  400. assert src < self.world_size, f"Invalid src rank ({src})"
  401. # Bypass the function if we are using only 1 GPU.
  402. if self.world_size == 1:
  403. return obj
  404. if self.mq_broadcaster is not None:
  405. assert src == 0, "Message queue broadcaster only supports src=0"
  406. return self.mq_broadcaster.broadcast_object(obj)
  407. if self.rank_in_group == src:
  408. torch.distributed.broadcast_object_list([obj],
  409. src=self.ranks[src],
  410. group=self.cpu_group)
  411. return obj
  412. else:
  413. recv = [None]
  414. torch.distributed.broadcast_object_list(recv,
  415. src=self.ranks[src],
  416. group=self.cpu_group)
  417. return recv[0]
  418. def broadcast_object_list(self,
  419. obj_list: List[Any],
  420. src: int = 0,
  421. group: Optional[ProcessGroup] = None):
  422. """Broadcast the input object list.
  423. NOTE: `src` is the local rank of the source rank.
  424. """
  425. assert src < self.world_size, f"Invalid src rank ({src})"
  426. # Bypass the function if we are using only 1 GPU.
  427. if self.world_size == 1:
  428. return obj_list
  429. # Broadcast.
  430. torch.distributed.broadcast_object_list(obj_list,
  431. src=self.ranks[src],
  432. group=self.device_group)
  433. return obj_list
  434. def send_object(self, obj: Any, dst: int) -> None:
  435. """Send the input object list to the destination rank."""
  436. """NOTE: `dst` is the local rank of the destination rank."""
  437. assert dst < self.world_size, f"Invalid dst rank ({dst})"
  438. assert dst != self.rank_in_group, (
  439. "Invalid destination rank. Destination rank is the same "
  440. "as the current rank.")
  441. # Serialize object to tensor and get the size as well
  442. object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
  443. size_tensor = torch.tensor([object_tensor.numel()],
  444. dtype=torch.long,
  445. device="cpu")
  446. # Send object size
  447. torch.distributed.send(size_tensor,
  448. dst=self.ranks[dst],
  449. group=self.cpu_group)
  450. # Send object
  451. torch.distributed.send(object_tensor,
  452. dst=self.ranks[dst],
  453. group=self.cpu_group)
  454. return None
  455. def recv_object(self, src: int) -> Any:
  456. """Receive the input object list from the source rank."""
  457. """NOTE: `src` is the local rank of the source rank."""
  458. assert src < self.world_size, f"Invalid src rank ({src})"
  459. assert src != self.rank_in_group, (
  460. "Invalid source rank. Source rank is the same as the current rank."
  461. )
  462. size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
  463. # Receive object size
  464. rank_size = torch.distributed.recv(size_tensor,
  465. src=self.ranks[src],
  466. group=self.cpu_group)
  467. # Tensor to receive serialized objects into.
  468. object_tensor = torch.empty( # type: ignore[call-overload]
  469. size_tensor.item(), # type: ignore[arg-type]
  470. dtype=torch.uint8,
  471. device="cpu")
  472. rank_object = torch.distributed.recv(object_tensor,
  473. src=self.ranks[src],
  474. group=self.cpu_group)
  475. assert rank_object == rank_size, (
  476. "Received object sender rank does not match the size sender rank.")
  477. obj = pickle.loads(object_tensor.numpy().tobytes())
  478. return obj
  479. def broadcast_tensor_dict(
  480. self,
  481. tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
  482. src: int = 0,
  483. group: Optional[ProcessGroup] = None,
  484. metadata_group: Optional[ProcessGroup] = None
  485. ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
  486. """Broadcast the input tensor dictionary.
  487. NOTE: `src` is the local rank of the source rank.
  488. """
  489. # Bypass the function if we are using only 1 GPU.
  490. if (not torch.distributed.is_initialized() or self.world_size == 1):
  491. return tensor_dict
  492. group = self.device_group
  493. metadata_group = self.cpu_group
  494. assert src < self.world_size, f"Invalid src rank ({src})"
  495. rank_in_group = self.rank_in_group
  496. if rank_in_group == src:
  497. metadata_list: List[Tuple[Any, Any]] = []
  498. assert isinstance(
  499. tensor_dict,
  500. dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
  501. metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
  502. # `metadata_list` lives in CPU memory.
  503. # `broadcast_object_list` has serialization & deserialization,
  504. # all happening on CPU. Therefore, we can use the CPU group.
  505. self.broadcast_object(metadata_list, src=src)
  506. async_handles = []
  507. for tensor in tensor_list:
  508. if tensor.numel() == 0:
  509. # Skip broadcasting empty tensors.
  510. continue
  511. if tensor.is_cpu:
  512. # use metadata_group for CPU tensors
  513. handle = torch.distributed.broadcast(tensor,
  514. src=self.ranks[src],
  515. group=metadata_group,
  516. async_op=True)
  517. else:
  518. # use group for GPU tensors
  519. handle = torch.distributed.broadcast(tensor,
  520. src=self.ranks[src],
  521. group=group,
  522. async_op=True)
  523. async_handles.append(handle)
  524. for async_handle in async_handles:
  525. async_handle.wait()
  526. else:
  527. metadata_list = self.broadcast_object(None, src=src)
  528. tensor_dict = {}
  529. async_handles = []
  530. for key, value in metadata_list:
  531. if isinstance(value, TensorMetadata):
  532. tensor = torch.empty(value.size,
  533. dtype=value.dtype,
  534. device=value.device)
  535. if tensor.numel() == 0:
  536. # Skip broadcasting empty tensors.
  537. tensor_dict[key] = tensor
  538. continue
  539. if tensor.is_cpu:
  540. # use metadata_group for CPU tensors
  541. handle = torch.distributed.broadcast(
  542. tensor,
  543. src=self.ranks[src],
  544. group=metadata_group,
  545. async_op=True)
  546. else:
  547. # use group for GPU tensors
  548. handle = torch.distributed.broadcast(
  549. tensor,
  550. src=self.ranks[src],
  551. group=group,
  552. async_op=True)
  553. async_handles.append(handle)
  554. tensor_dict[key] = tensor
  555. else:
  556. tensor_dict[key] = value
  557. for async_handle in async_handles:
  558. async_handle.wait()
  559. return tensor_dict
  560. def send_tensor_dict(
  561. self,
  562. tensor_dict: Dict[str, Union[torch.Tensor, Any]],
  563. dst: Optional[int] = None,
  564. all_gather_group: Optional["GroupCoordinator"] = None,
  565. ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
  566. """Send the input tensor dictionary.
  567. NOTE: `dst` is the local rank of the source rank.
  568. """
  569. # Bypass the function if we are using only 1 GPU.
  570. if not torch.distributed.is_initialized() or self.world_size == 1:
  571. return tensor_dict
  572. all_gather_size = (1 if all_gather_group is None else
  573. all_gather_group.world_size)
  574. all_gather_rank = (0 if all_gather_group is None else
  575. all_gather_group.rank_in_group)
  576. group = self.device_group
  577. metadata_group = self.cpu_group
  578. if dst is None:
  579. dst = (self.rank_in_group + 1) % self.world_size
  580. assert dst < self.world_size, f"Invalid dst rank ({dst})"
  581. metadata_list: List[Tuple[Any, Any]] = []
  582. assert isinstance(
  583. tensor_dict,
  584. dict), f"Expecting a dictionary, got {type(tensor_dict)}"
  585. metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
  586. # `metadata_list` lives in CPU memory.
  587. # `send_object_list` has serialization & deserialization,
  588. # all happening on CPU. Therefore, we can use the CPU group.
  589. self.send_object(metadata_list, dst=dst)
  590. for tensor in tensor_list:
  591. if tensor.numel() == 0:
  592. # Skip sending empty tensors.
  593. continue
  594. # send-allgather: send only a slice, then do allgather.
  595. if (all_gather_group is not None
  596. and tensor.numel() % all_gather_size == 0):
  597. tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
  598. if tensor.is_cpu:
  599. # use metadata_group for CPU tensors
  600. torch.distributed.send(tensor,
  601. dst=self.ranks[dst],
  602. group=metadata_group)
  603. else:
  604. # use group for GPU tensors
  605. torch.distributed.send(tensor,
  606. dst=self.ranks[dst],
  607. group=group)
  608. return None
  609. def recv_tensor_dict(
  610. self,
  611. src: Optional[int] = None,
  612. all_gather_group: Optional["GroupCoordinator"] = None,
  613. ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
  614. """Recv the input tensor dictionary.
  615. NOTE: `src` is the local rank of the source rank.
  616. """
  617. # Bypass the function if we are using only 1 GPU.
  618. if not torch.distributed.is_initialized() or self.world_size == 1:
  619. return None
  620. all_gather_size = (1 if all_gather_group is None else
  621. all_gather_group.world_size)
  622. all_gather_rank = (0 if all_gather_group is None else
  623. all_gather_group.rank_in_group)
  624. group = self.device_group
  625. metadata_group = self.cpu_group
  626. if src is None:
  627. src = (self.rank_in_group - 1) % self.world_size
  628. assert src < self.world_size, f"Invalid src rank ({src})"
  629. recv_metadata_list = self.recv_object(src=src)
  630. tensor_dict: Dict[str, Any] = {}
  631. for key, value in recv_metadata_list:
  632. if isinstance(value, TensorMetadata):
  633. tensor = torch.empty(value.size,
  634. dtype=value.dtype,
  635. device=value.device)
  636. if tensor.numel() == 0:
  637. # Skip broadcasting empty tensors.
  638. tensor_dict[key] = tensor
  639. continue
  640. # send-allgather: send only a slice, then do allgather.
  641. use_all_gather = (all_gather_group is not None
  642. and tensor.numel() % all_gather_size == 0)
  643. if use_all_gather:
  644. orig_shape = tensor.shape
  645. tensor = tensor.reshape(all_gather_size,
  646. -1)[all_gather_rank]
  647. if tensor.is_cpu:
  648. # use metadata_group for CPU tensors
  649. torch.distributed.recv(tensor,
  650. src=self.ranks[src],
  651. group=metadata_group)
  652. else:
  653. # use group for GPU tensors
  654. torch.distributed.recv(tensor,
  655. src=self.ranks[src],
  656. group=group)
  657. if use_all_gather:
  658. # do the allgather
  659. tensor = all_gather_group.all_gather( # type: ignore
  660. tensor, dim=0)
  661. tensor = tensor.reshape(orig_shape)
  662. tensor_dict[key] = tensor
  663. else:
  664. tensor_dict[key] = value
  665. return tensor_dict
  666. def barrier(self):
  667. """Barrier synchronization among the group.
  668. NOTE: don't use `device_group` here! `barrier` in NCCL is
  669. terrible because it is internally a broadcast operation with
  670. secretly created GPU tensors. It is easy to mess up the current
  671. device. Use the CPU group instead.
  672. """
  673. torch.distributed.barrier(group=self.cpu_group)
  def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
      """Send ``tensor`` to rank ``dst`` of this group.

      Args:
          tensor: The tensor to send.
          dst: Rank of the destination *within this group* (an index into
              ``self.ranks``, not a global rank). Defaults to the next rank
              in ring order, ``(rank_in_group + 1) % world_size``.

      Raises:
          AssertionError: If ``dst`` is not a valid in-group rank.

      The pynccl communicator is used when it exists and is not disabled.
      Otherwise the call falls back to ``torch.distributed.send`` on
      ``device_group``, which is a synchronous send. NOTE(review): whether
      the pynccl path blocks is not visible here; do not rely on
      non-blocking behaviour without checking it.
      """
      if dst is None:
          dst = (self.rank_in_group + 1) % self.world_size
      # Validate up front. A negative rank would otherwise silently wrap
      # through ``self.ranks[dst]``, and an out-of-range one would reach the
      # communicator as a bogus peer.
      assert 0 <= dst < self.world_size, f"Invalid dst rank ({dst})"
      pynccl_comm = self.pynccl_comm
      if pynccl_comm is not None and not pynccl_comm.disabled:
          pynccl_comm.send(tensor, dst)
      else:
          # torch.distributed expects the peer's global rank.
          torch.distributed.send(tensor, self.ranks[dst], self.device_group)
  def recv(self,
           size: torch.Size,
           dtype: torch.dtype,
           src: Optional[int] = None) -> torch.Tensor:
      """Receive a tensor from rank ``src`` of this group.

      A new tensor with the given ``size`` and ``dtype`` is allocated on
      ``self.device`` and filled by the receive.

      Args:
          size: Shape of the tensor to allocate and receive into.
              NOTE(review): presumably must match the sender's tensor; this
              is not checked here.
          dtype: Element type of the tensor to receive.
          src: Rank of the *source* within this group (an index into
              ``self.ranks``, not a global rank). Defaults to the previous
              rank in ring order, ``(rank_in_group - 1) % world_size``.

      Returns:
          The received tensor.

      Raises:
          AssertionError: If ``src`` is not a valid in-group rank.
      """
      if src is None:
          src = (self.rank_in_group - 1) % self.world_size
      # Validate before allocating. A negative rank would otherwise silently
      # wrap through ``self.ranks[src]``.
      assert 0 <= src < self.world_size, f"Invalid src rank ({src})"
      tensor = torch.empty(size, dtype=dtype, device=self.device)
      pynccl_comm = self.pynccl_comm
      if pynccl_comm is not None and not pynccl_comm.disabled:
          pynccl_comm.recv(tensor, src)
      else:
          # torch.distributed expects the peer's global rank.
          torch.distributed.recv(tensor, self.ranks[src], self.device_group)
      return tensor
  def destroy(self):
      """Tear down this coordinator's process groups and drop its helpers.

      ``device_group`` and then ``cpu_group`` are each passed to
      ``torch.distributed.destroy_process_group`` (when set) and cleared.
      The pynccl communicator, custom all-reduce communicator and message
      queue broadcaster are only dereferenced (set to ``None``); no explicit
      shutdown is called on them here.
      """
      for group_attr in ("device_group", "cpu_group"):
          process_group = getattr(self, group_attr)
          if process_group is None:
              continue
          torch.distributed.destroy_process_group(process_group)
          setattr(self, group_attr, None)
      # Helper objects are released simply by dropping the reference.
      for helper_attr in ("pynccl_comm", "ca_comm", "mq_broadcaster"):
          if getattr(self, helper_attr) is not None:
              setattr(self, helper_attr, None)
  # Coordinator for the group containing every rank; populated by
  # ``init_distributed_environment``.
  _WORLD: Optional[GroupCoordinator] = None


  def get_world_group() -> GroupCoordinator:
      """Return the world group coordinator.

      Raises:
          AssertionError: if the world group has not been created yet.
      """
      world = _WORLD
      assert world is not None, ("world group is not initialized")
      return world
  def init_world_group(ranks: List[int], local_rank: int,
                       backend: str) -> GroupCoordinator:
      """Build the "world" coordinator: one group holding all of ``ranks``.

      pynccl, custom all-reduce and the TPU communicator are all disabled
      for this group.
      """
      world_options = dict(
          torch_distributed_backend=backend,
          use_pynccl=False,
          use_custom_allreduce=False,
          use_tpu_communicator=False,
          group_name="world",
      )
      return GroupCoordinator(group_ranks=[ranks],
                              local_rank=local_rank,
                              **world_options)
  def init_model_parallel_group(
      group_ranks: List[List[int]],
      local_rank: int,
      backend: str,
      use_custom_allreduce: Optional[bool] = None,
      use_message_queue_broadcaster: bool = False,
      group_name: Optional[str] = None,
  ) -> GroupCoordinator:
      """Build a coordinator for a model-parallel (e.g. TP or PP) group.

      Args:
          group_ranks: One list of global ranks per group.
          local_rank: Local rank of this process.
          backend: torch.distributed backend name.
          use_custom_allreduce: Enable custom all-reduce. ``None`` means
              "use the module-wide default" (see ``set_custom_all_reduce``).
          use_message_queue_broadcaster: Enable the message queue
              broadcaster for this group.
          group_name: Optional name for the group.

      pynccl and the TPU communicator are always requested here.
      """
      custom_allreduce = (_ENABLE_CUSTOM_ALL_REDUCE
                          if use_custom_allreduce is None else
                          use_custom_allreduce)
      return GroupCoordinator(
          group_ranks=group_ranks,
          local_rank=local_rank,
          torch_distributed_backend=backend,
          use_pynccl=True,
          use_custom_allreduce=custom_allreduce,
          use_tpu_communicator=True,
          use_message_queue_broadcaster=use_message_queue_broadcaster,
          group_name=group_name,
      )
  # Tensor-model-parallel coordinator; populated by initialize_model_parallel.
  _TP: Optional[GroupCoordinator] = None


  def get_tp_group() -> GroupCoordinator:
      """Return the tensor-model-parallel group coordinator.

      Raises:
          AssertionError: if the TP group has not been created yet.
      """
      tp_group = _TP
      assert tp_group is not None, ("tensor model parallel group is not initialized")
      return tp_group


  # Legacy alias kept so older call sites keep working.
  get_tensor_model_parallel_group = get_tp_group
  # Pipeline-model-parallel coordinator; populated by initialize_model_parallel.
  _PP: Optional[GroupCoordinator] = None


  def get_pp_group() -> GroupCoordinator:
      """Return the pipeline-model-parallel group coordinator.

      Raises:
          AssertionError: if the PP group has not been created yet.
      """
      pp_group = _PP
      assert pp_group is not None, (
          "pipeline model parallel group is not initialized")
      return pp_group


  # Legacy alias kept so older call sites keep working.
  get_pipeline_model_parallel_group = get_pp_group
  @contextmanager
  def graph_capture():
      """Context manager that must surround code capturing a CUDA graph.

      The TP group's ``graph_capture()`` is entered first. The context object
      it yields is then passed to the PP group's ``graph_capture(context)``,
      so both groups share a single capture context. That context is what
      this manager yields, and both groups' managers are exited when the
      ``with`` block ends.

      NOTE(review): per the module's design the context carries the stream
      the capture runs on. That stream is meant to be separate from the
      default stream so captured kernels are not mixed with kernels launched
      in the background on the default stream. The stream handling itself
      lives in ``GroupCoordinator.graph_capture`` and is not visible here.
      """
      with get_tp_group().graph_capture() as capture_context:
          with get_pp_group().graph_capture(capture_context):
              yield capture_context
  # Module-wide default read by ``init_model_parallel_group`` when its
  # ``use_custom_allreduce`` argument is left as ``None``.
  _ENABLE_CUSTOM_ALL_REDUCE = True


  def set_custom_all_reduce(enable: bool):
      """Set the module-wide default for custom all-reduce.

      The value is read when a model-parallel group is created, so it only
      affects groups built afterwards.
      """
      global _ENABLE_CUSTOM_ALL_REDUCE
      _ENABLE_CUSTOM_ALL_REDUCE = enable
  def init_distributed_environment(
      world_size: int = -1,
      rank: int = -1,
      distributed_init_method: str = "env://",
      local_rank: int = -1,
      backend: str = "nccl",
  ):
      """Initialize torch.distributed (if needed) and the world group.

      Args:
          world_size: Total number of processes, forwarded to
              ``torch.distributed.init_process_group``.
          rank: Global rank of this process.
          distributed_init_method: Rendezvous URL, e.g. ``"env://"`` or
              ``"tcp://host:port"``.
          local_rank: Local rank of this process. ``-1`` means "derive it":
              from ``envs.LOCAL_RANK`` for ``env://``, otherwise from ``rank``.
          backend: torch.distributed backend. It is forced to ``"gloo"`` on
              Windows when a ``tcp://`` init method is used.

      If torch.distributed is already initialized, the process group is left
      alone. Only the world group is created, or checked against the current
      world size if it already exists.
      """
      logger.debug(
          f"world_size={world_size} rank={rank} local_rank={local_rank} "
          f"distributed_init_method={distributed_init_method} backend={backend}")
      if not torch.distributed.is_initialized():
          assert distributed_init_method is not None, (
              "distributed_init_method must be provided when initializing "
              "distributed environment")
          if (sys.platform.startswith("win32")
                  and distributed_init_method.startswith("tcp://")):
              # Opt out of the libuv TCPStore and use gloo on Windows.
              # NOTE(review): presumably because the Windows torch build lacks
              # libuv/NCCL support; confirm.
              # Join with "&" when the URL already has a query string so it
              # stays a valid URL, and respect an explicit use_libuv setting.
              if "use_libuv" not in distributed_init_method:
                  separator = "&" if "?" in distributed_init_method else "?"
                  distributed_init_method += f"{separator}use_libuv=0"
              backend = "gloo"
          # this backend is used for WORLD
          torch.distributed.init_process_group(
              backend=backend,
              init_method=distributed_init_method,
              world_size=world_size,
              rank=rank)
      # set the local rank
      # local_rank is not available in torch ProcessGroup,
      # see https://github.com/pytorch/pytorch/issues/122816
      if local_rank == -1:
          # Local rank not given. For env:// read it from the environment;
          # otherwise use the global rank, which is correct on a single node.
          if distributed_init_method == "env://":
              local_rank = envs.LOCAL_RANK
          else:
              local_rank = rank
      global _WORLD
      if _WORLD is None:
          ranks = list(range(torch.distributed.get_world_size()))
          _WORLD = init_world_group(ranks, local_rank, backend)
      else:
          assert _WORLD.world_size == torch.distributed.get_world_size(), (
              "world group already initialized with a different world size")
  def initialize_model_parallel(
      tensor_model_parallel_size: int = 1,
      pipeline_model_parallel_size: int = 1,
      backend: Optional[str] = None,
  ) -> None:
      """
      Create the tensor-parallel (TP) and pipeline-parallel (PP) groups.

      Arguments:
          tensor_model_parallel_size: number of GPUs used for tensor model
              parallelism.
          pipeline_model_parallel_size: number of GPUs used for pipeline model
              parallelism.
          backend: torch.distributed backend for the new groups; when falsy,
              the backend of the world group's device group is used.

      TP groups are runs of consecutive global ranks. PP groups take every
      ``world_size // pipeline_model_parallel_size``-th rank. For example,
      with 8 GPUs g0 ... g7, TP size 2 and PP size 4:

          4 tensor model-parallel groups:
              [g0, g1], [g2, g3], [g4, g5], [g6, g7]
          2 pipeline model-parallel groups:
              [g0, g2, g4, g6], [g1, g3, g5, g7]

      For efficiency, the caller should place adjacent ranks on the same
      machine. For example, with two DGX-1 boxes (16 GPUs), ranks 0-7 should
      be on the first box and ranks 8-15 on the second.

      Raises:
          RuntimeError: if the world size is not
              ``tensor_model_parallel_size * pipeline_model_parallel_size``.
          AssertionError: if torch.distributed is not initialized, or if a TP
              or PP group already exists.
      """
      global _TP, _PP
      # Get world size and rank. Ensure some consistencies.
      assert torch.distributed.is_initialized()
      world_size: int = torch.distributed.get_world_size()
      if not backend:
          backend = torch.distributed.get_backend(
              get_world_group().device_group)
      if world_size != tensor_model_parallel_size * pipeline_model_parallel_size:
          raise RuntimeError(
              f"world_size ({world_size}) is not equal to "
              f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
              f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")

      # TP groups: consecutive blocks of tensor_model_parallel_size ranks.
      num_tp_groups: int = world_size // tensor_model_parallel_size
      assert _TP is None, ("tensor model parallel group is already initialized")
      tp_group_ranks = [
          list(range(g * tensor_model_parallel_size,
                     (g + 1) * tensor_model_parallel_size))
          for g in range(num_tp_groups)
      ]
      # message queue broadcaster is only used in tensor model parallel group
      _TP = init_model_parallel_group(tp_group_ranks,
                                      get_world_group().local_rank,
                                      backend,
                                      use_message_queue_broadcaster=True,
                                      group_name="tp")

      # PP groups: ranks strided by the number of PP groups.
      num_pp_groups: int = world_size // pipeline_model_parallel_size
      assert _PP is None, (
          "pipeline model parallel group is already initialized")
      pp_group_ranks = [
          list(range(g, world_size, num_pp_groups))
          for g in range(num_pp_groups)
      ]
      # pipeline parallel does not need custom allreduce
      _PP = init_model_parallel_group(pp_group_ranks,
                                      get_world_group().local_rank,
                                      backend,
                                      use_custom_allreduce=False,
                                      group_name="pp")
  def ensure_model_parallel_initialized(
      tensor_model_parallel_size: int,
      pipeline_model_parallel_size: int,
      backend: Optional[str] = None,
  ) -> None:
      """Create the model-parallel groups, or verify the existing ones.

      When the TP/PP groups do not exist yet, they are created through
      ``initialize_model_parallel``. Otherwise their sizes are compared with
      the requested ones.

      Raises:
          AssertionError: if the groups already exist with different sizes.
      """
      if not backend:
          backend = torch.distributed.get_backend(
              get_world_group().device_group)
      if model_parallel_is_initialized():
          tp_matches = (get_tensor_model_parallel_world_size() ==
                        tensor_model_parallel_size)
          assert tp_matches, (
              "tensor parallel group already initialized, but of unexpected size: "
              f"{get_tensor_model_parallel_world_size()=} vs. "
              f"{tensor_model_parallel_size=}")
          pp_world_size = get_pp_group().world_size
          assert pp_world_size == pipeline_model_parallel_size, (
              "pipeline parallel group already initialized, but of unexpected size: "
              f"{pp_world_size=} vs. "
              f"{pipeline_model_parallel_size=}")
      else:
          initialize_model_parallel(tensor_model_parallel_size,
                                    pipeline_model_parallel_size, backend)
  def model_parallel_is_initialized():
      """Return True only when both the TP and the PP group exist."""
      if _TP is None:
          return False
      return _PP is not None
  # True while patch_tensor_parallel_group has a replacement TP group
  # installed; used to reject nested patching.
  _TP_STATE_PATCHED = False


  @contextmanager
  def patch_tensor_parallel_group(tp_group: GroupCoordinator):
      """Temporarily replace the TP group for the duration of the ``with``.

      Used by speculative-decoding draft workers so that the draft model can
      run with a different TP degree from the target model workers.

      Args:
          tp_group (GroupCoordinator): the coordinator to install as TP group.

      Raises:
          AssertionError: if a patch is already active, or if no TP group is
              currently initialized.
      """
      global _TP_STATE_PATCHED, _TP
      assert not _TP_STATE_PATCHED, "Should not call when it's already patched"
      # Fetch the current group *before* setting the flag. If the TP group is
      # not initialized, get_tp_group() raises; setting the flag first would
      # leave it stuck at True and break every later patch attempt.
      old_tp_group = get_tp_group()
      _TP_STATE_PATCHED = True
      _TP = tp_group
      try:
          yield
      finally:
          # restore the original state, even if the body raised
          _TP_STATE_PATCHED = False
          _TP = old_tp_group
  def get_tensor_model_parallel_world_size():
      """Return the number of ranks in the tensor model parallel group."""
      tp_group = get_tp_group()
      return tp_group.world_size
  def get_tensor_model_parallel_rank():
      """Return this process's rank inside the tensor model parallel group."""
      tp_group = get_tp_group()
      return tp_group.rank_in_group
  def destroy_model_parallel():
      """Destroy the TP and PP group coordinators and reset them to None.

      Safe to call repeatedly, or when a group was never created.
      """
      global _TP, _PP
      # Compare against None explicitly instead of relying on the
      # coordinator's truthiness.
      if _TP is not None:
          _TP.destroy()
      _TP = None

      if _PP is not None:
          _PP.destroy()
      _PP = None
  def destroy_distributed_environment():
      """Destroy the world group and the default torch.distributed group.

      The world coordinator (if any) is destroyed and reset to None. The
      default process group is then destroyed if torch.distributed is still
      initialized. Safe to call repeatedly.
      """
      global _WORLD
      # Explicit None check instead of relying on the coordinator's truthiness.
      if _WORLD is not None:
          _WORLD.destroy()
      _WORLD = None
      if torch.distributed.is_initialized():
          torch.distributed.destroy_process_group()
  def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
      """
      Collective: report, for every rank in ``pg``, whether it is on the same
      node as ``source_rank``. Every rank of ``pg`` must call this.

      The source rank creates a small ``shared_memory.SharedMemory`` segment,
      writes a magic message into it, and broadcasts the segment name. Every
      other rank tries to open the segment by name and compares its contents.
      A rank that can read the message is attached to the same memory system
      (shared access to shared memory). Per-rank flags are combined with an
      ``all_reduce`` over ``pg``.

      Args:
          pg: A non-NCCL process group (asserted below). The result tensor
              used in the all-reduce is created on the CPU.
          source_rank: Rank *within* ``pg`` that creates the segment. It is
              mapped to a global rank via ``get_process_group_ranks`` for
              the broadcast.

      Returns:
          A list of length ``world_size`` of ``pg``. Entry ``i`` is True when
          the reduced flag of rank ``i`` equals 1. The source rank sets its
          own flag only after its broadcast succeeds.

      Errors are not raised. ``OSError`` (e.g. the segment name not existing
      on another node) is silently suppressed, and any other exception is
      logged and ignored. The affected rank then reports False.
      """
      assert torch.distributed.get_backend(
          pg) != torch.distributed.Backend.NCCL, (
              "in_the_same_node_as should be tested with a non-NCCL group.")
      # local rank inside the group
      rank = torch.distributed.get_rank(group=pg)
      world_size = torch.distributed.get_world_size(group=pg)

      # local tensor in each process to store the result
      # (one int32 flag per rank of pg; each rank only writes its own slot,
      # and the all_reduce at the end merges them)
      is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

      # global ranks of the processes in the group
      ranks = torch.distributed.get_process_group_ranks(pg)

      magic_message = b"magic_message"
      # Bound before the try so the finally clause and the unlink below can
      # always refer to it.
      shm = None

      try:
          # OSError is swallowed here and the rank's flag simply stays 0.
          # NOTE(review): if creating the segment raises OSError on the source
          # rank, the source skips broadcast_object_list while the other ranks
          # are still waiting in it. This looks like a potential hang; confirm.
          with contextlib.suppress(OSError):
              if rank == source_rank:
                  # create a shared memory segment
                  shm = shared_memory.SharedMemory(create=True, size=128)
                  shm.buf[:len(magic_message)] = magic_message
                  # Only the segment *name* is broadcast; receivers read the
                  # contents through shared memory.
                  torch.distributed.broadcast_object_list([shm.name],
                                                          src=ranks[source_rank],
                                                          group=pg)
                  is_in_the_same_node[rank] = 1
              else:
                  # try to open the shared memory segment
                  recv = [None]
                  torch.distributed.broadcast_object_list(recv,
                                                          src=ranks[source_rank],
                                                          group=pg)
                  name = recv[0]
                  # fix to https://stackoverflow.com/q/62748654/9191338
                  # Python incorrectly tracks shared memory even if it is not
                  # created by the process. The following patch is a workaround.
                  with patch("multiprocessing.resource_tracker.register",
                             lambda *args, **kwargs: None):
                      shm = shared_memory.SharedMemory(name=name)
                  # Seeing the source's magic bytes means this rank shares
                  # memory with the source rank.
                  if shm.buf[:len(magic_message)] == magic_message:
                      is_in_the_same_node[rank] = 1
      except Exception as e:
          logger.error(f"Error ignored in is_in_the_same_node: {e}")
      finally:
          # Every rank closes its own handle; only the source unlinks (below).
          if shm:
              shm.close()

      # Make sure all ranks are done with the segment before it is unlinked.
      torch.distributed.barrier(group=pg)

      # clean up the shared memory segment
      with contextlib.suppress(OSError):
          if rank == source_rank and shm:
              shm.unlink()
      # Merge the per-rank flags (all_reduce defaults to a SUM reduction).
      torch.distributed.all_reduce(is_in_the_same_node, group=pg)

      return [x == 1 for x in is_in_the_same_node.tolist()]
  def _resolve_tp_partition(total_size: int, tp_rank: Optional[int],
                            tp_size: Optional[int], multiple_of: int):
      """Shared argument handling for the two TP partition helpers below.

      Fills in ``tp_rank`` / ``tp_size`` from the current TP group when they
      are ``None``. Checks that ``total_size`` is divisible by ``multiple_of``
      and converts it into a count of ``multiple_of``-sized units.

      Returns:
          ``(num_units, tp_rank, tp_size)``.
      """
      if tp_rank is None:
          tp_rank = get_tensor_model_parallel_rank()
      if tp_size is None:
          tp_size = get_tensor_model_parallel_world_size()
      assert total_size % multiple_of == 0, (
          f"total_size ({total_size}) must be a multiple of "
          f"multiple_of ({multiple_of})")
      return total_size // multiple_of, tp_rank, tp_size


  def get_current_tp_rank_partition_offset(total_size: int,
                                           tp_rank: Optional[int] = None,
                                           tp_size: Optional[int] = None,
                                           multiple_of: int = 1) -> int:
      """Return the start offset of ``tp_rank``'s shard of ``total_size``.

      ``total_size`` is cut into ``total_size // multiple_of`` units, which
      are spread as evenly as possible over ``tp_size`` ranks. The first
      ``units % tp_size`` ranks get one extra unit. The result is expressed
      in the original scale, so it is always a multiple of ``multiple_of``.
      ``tp_rank`` / ``tp_size`` default to the current TP group.

      Example: ``total_size=10, tp_size=4`` gives offsets ``[0, 3, 6, 8]``.

      Raises:
          AssertionError: if ``total_size`` is not divisible by
              ``multiple_of``.
      """
      units, tp_rank, tp_size = _resolve_tp_partition(total_size, tp_rank,
                                                      tp_size, multiple_of)
      base, remainder = divmod(units, tp_size)
      # Every earlier rank contributes ``base`` units, plus one extra for
      # each earlier rank that falls inside the remainder.
      return (base * tp_rank + min(remainder, tp_rank)) * multiple_of


  def get_current_tp_rank_partition_size(total_size: int,
                                         tp_rank: Optional[int] = None,
                                         tp_size: Optional[int] = None,
                                         multiple_of: int = 1) -> int:
      """Return the size of ``tp_rank``'s shard of ``total_size``.

      Uses the same split as ``get_current_tp_rank_partition_offset``: the
      first ``units % tp_size`` ranks get ``units // tp_size + 1`` units and
      the rest get ``units // tp_size``. The result is scaled back by
      ``multiple_of``. ``tp_rank`` / ``tp_size`` default to the current TP
      group.

      Example: ``total_size=10, tp_size=4`` gives sizes ``[3, 3, 2, 2]``.

      Raises:
          AssertionError: if ``total_size`` is not divisible by
              ``multiple_of``.
      """
      units, tp_rank, tp_size = _resolve_tp_partition(total_size, tp_rank,
                                                      tp_size, multiple_of)
      base, remainder = divmod(units, tp_size)
      extra = 1 if tp_rank < remainder else 0
      return (base + extra) * multiple_of