# parallel_state.py

# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Aphrodite distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
  initialize the model parallel groups.

- any code dealing with the distributed stuff

- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.

If you only need to use the distributed environment without model/pipeline
parallelism, you can skip the model parallel initialization and destruction
steps.
"""
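
# Usage sketch (illustrative only, not part of the module): a minimal
# single-process run. The module path `aphrodite.distributed.parallel_state`
# and the rendezvous address are assumptions; adapt world_size/rank/backend to
# your launcher.
#
#     from aphrodite.distributed.parallel_state import (
#         destroy_distributed_environment, destroy_model_parallel,
#         ensure_model_parallel_initialized, init_distributed_environment)
#
#     init_distributed_environment(world_size=1, rank=0,
#                                  distributed_init_method="tcp://127.0.0.1:29500",
#                                  local_rank=0, backend="gloo")
#     ensure_model_parallel_initialized(tensor_model_parallel_size=1,
#                                       pipeline_model_parallel_size=1)
#     ...  # code using the distributed environment
#     destroy_model_parallel()
#     destroy_distributed_environment()
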
import contextlib
import pickle
import sys
import weakref
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch

import torch
import torch.distributed
from loguru import logger
from torch.distributed import Backend, ProcessGroup

import aphrodite.common.envs as envs


@dataclass
class GraphCaptureContext:
    stream: torch.cuda.Stream


TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])


def _split_tensor_dict(
    tensor_dict: Dict[str, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
       by its metadata.
    2. A list of tensors.
    """
    metadata_list: List[Tuple[str, Any]] = []
    tensor_list: List[torch.Tensor] = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Note: we cannot use `value.device` here,
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type.
            # The receiving side will set the device index.
            device = value.device.type
            metadata_list.append(
                (key, TensorMetadata(device, value.dtype, value.size())))
            tensor_list.append(value)
        else:
            metadata_list.append((key, value))
    return metadata_list, tensor_list
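
# Illustrative example (a sketch, not executed by the module): how a mixed
# dictionary is split into picklable metadata and a flat tensor list.
#
#     d = {"x": torch.zeros(2, 3), "step": 7}
#     metadata_list, tensor_list = _split_tensor_dict(d)
#     # metadata_list == [("x", TensorMetadata("cpu", torch.float32,
#     #                                        torch.Size([2, 3]))),
#     #                   ("step", 7)]
#     # tensor_list   == [d["x"]]
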
_group_name_counter: Dict[str, int] = {}


def _get_unique_name(name: str) -> str:
    """Get a unique name for the group.
    Example:
    _get_unique_name("tp") -> "tp:0"
    _get_unique_name("tp") -> "tp:1"
    """
    if name not in _group_name_counter:
        _group_name_counter[name] = 0
    newname = f"{name}:{_group_name_counter[name]}"
    _group_name_counter[name] += 1
    return newname


_groups: Dict[str, Callable[[], "GroupCoordinator"]] = {}


def _register_group(group: "GroupCoordinator") -> None:
    # looks like Python 3.8 does not understand `ReferenceType`
    _groups[group.unique_name] = weakref.ref(group)  # type: ignore


# Some backends use pytorch version < 2.4.0 which doesn't
# support `torch.library.custom_op`.
def supports_custom_op() -> bool:
    return hasattr(torch.library, "custom_op")


if supports_custom_op():

    @torch.library.custom_op("aphrodite::inplace_all_reduce",
                             mutates_args=["tensor"])
    def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
        assert group_name in _groups, f"Group {group_name} is not found."
        group = _groups[group_name]()
        if group is None:
            raise ValueError(f"Group {group_name} is destroyed.")
        group._all_reduce(tensor)

    @inplace_all_reduce.register_fake
    def _(tensor: torch.Tensor, group_name: str) -> None:
        return

    @torch.library.custom_op("aphrodite::outplace_all_reduce", mutates_args=[])
    def outplace_all_reduce(tensor: torch.Tensor,
                            group_name: str) -> torch.Tensor:
        assert group_name in _groups, f"Group {group_name} is not found."
        group = _groups[group_name]()
        if group is None:
            raise ValueError(f"Group {group_name} is destroyed.")
        return group._all_reduce(tensor)

    @outplace_all_reduce.register_fake
    def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
        return torch.empty_like(tensor)


class GroupCoordinator:
    """
    PyTorch ProcessGroup wrapper for a group of processes.
    PyTorch ProcessGroup is bound to one specific communication backend,
    e.g. NCCL, Gloo, MPI, etc.
    GroupCoordinator takes charge of all the communication operations among
    the processes in the group. It can route the communication to
    a specific implementation (e.g. switch allreduce implementation
    based on the tensor size and cuda graph mode).
    """

    # available attributes:
    rank: int  # global rank
    ranks: List[int]  # global ranks in the group
    world_size: int  # size of the group
    # difference between `local_rank` and `rank_in_group`:
    # if we have a group of size 4 across two nodes:
    # Process | Node |  Rank | Local Rank | Rank in Group
    #   0     |   0  |   0   |     0      |       0
    #   1     |   0  |   1   |     1      |       1
    #   2     |   1  |   2   |     0      |       2
    #   3     |   1  |   3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    use_pynccl: bool  # a hint of whether to use PyNccl
    use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
    # communicators are only created for world size > 1
    pynccl_comm: Optional[Any]  # PyNccl communicator
    ca_comm: Optional[Any]  # Custom allreduce communicator
    mq_broadcaster: Optional[Any]  # shared memory broadcaster

    def __init__(
        self,
        group_ranks: List[List[int]],
        local_rank: int,
        torch_distributed_backend: Union[str, Backend],
        use_pynccl: bool,
        use_custom_allreduce: bool,
        use_tpu_communicator: bool,
        use_message_queue_broadcaster: bool = False,
        group_name: Optional[str] = None,
    ):
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
        _register_group(self)

        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank
        self.device_group = None
        self.cpu_group = None

        for ranks in group_ranks:
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend)
            # a group with `gloo` backend, to allow direct coordination between
            # processes through the CPU.
            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
                self.device_group = device_group
                self.cpu_group = cpu_group

        assert self.cpu_group is not None
        assert self.device_group is not None

        if torch.cuda.is_available():
            self.device = torch.device(f"cuda:{local_rank}")
        else:
            self.device = torch.device("cpu")

        self.use_pynccl = use_pynccl
        self.use_custom_allreduce = use_custom_allreduce
        self.use_tpu_communicator = use_tpu_communicator

        # lazy import to avoid documentation build error
        from aphrodite.distributed.device_communicators.custom_all_reduce import (  # noqa: E501
            CustomAllreduce)
        from aphrodite.distributed.device_communicators.pynccl import (
            PyNcclCommunicator)

        self.pynccl_comm: Optional[PyNcclCommunicator] = None
        if use_pynccl and self.world_size > 1:
            self.pynccl_comm = PyNcclCommunicator(
                group=self.cpu_group,
                device=self.device,
            )

        self.ca_comm: Optional[CustomAllreduce] = None
        if use_custom_allreduce and self.world_size > 1:
            # Initialize a custom fast all-reduce implementation.
            self.ca_comm = CustomAllreduce(
                group=self.cpu_group,
                device=self.device,
            )

        from aphrodite.distributed.device_communicators.tpu_communicator import (  # noqa: E501
            TpuCommunicator)
        self.tpu_communicator: Optional[TpuCommunicator] = None
        if use_tpu_communicator and self.world_size > 1:
            self.tpu_communicator = TpuCommunicator(group=self.cpu_group)

        from aphrodite.distributed.device_communicators.shm_broadcast import (
            MessageQueue)
        self.mq_broadcaster: Optional[MessageQueue] = None
        if use_message_queue_broadcaster and self.world_size > 1:
            self.mq_broadcaster = MessageQueue.create_from_process_group(
                self.cpu_group, 1 << 22, 6)

    @property
    def first_rank(self):
        """Return the global rank of the first process in the group"""
        return self.ranks[0]

    @property
    def last_rank(self):
        """Return the global rank of the last process in the group"""
        return self.ranks[-1]

    @property
    def is_first_rank(self):
        """Return whether the caller is the first process in the group"""
        return self.rank == self.first_rank

    @property
    def is_last_rank(self):
        """Return whether the caller is the last process in the group"""
        return self.rank == self.last_rank

    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group - 1) % world_size]

    @contextmanager
    def graph_capture(
            self, graph_capture_context: Optional[GraphCaptureContext] = None):
        if graph_capture_context is None:
            stream = torch.cuda.Stream()
            graph_capture_context = GraphCaptureContext(stream)
        else:
            stream = graph_capture_context.stream

        ca_comm = self.ca_comm
        maybe_ca_context = nullcontext(
        ) if ca_comm is None else ca_comm.capture()

        # ensure all initialization operations complete before attempting to
        # capture the graph on another stream
        curr_stream = torch.cuda.current_stream()
        if curr_stream != stream:
            stream.wait_stream(curr_stream)

        with torch.cuda.stream(stream), maybe_ca_context:
            # In graph mode, we have to be very careful about the collective
            # operations. The current status is:
            #     allreduce \ Mode   |  Eager  |  Graph  |
            # --------------------------------------------
            # custom allreduce       | enabled | enabled |
            # PyNccl                 | disabled| enabled |
            # torch.distributed      | enabled | disabled|
            #
            # Note that custom allreduce will have a runtime check, if the
            # tensor size is too large, it will fallback to the next
            # available option.
            # In summary: When using CUDA graph, we use
            # either custom all-reduce kernel or pynccl. When not using
            # CUDA graph, we use either custom all-reduce kernel or
            # PyTorch NCCL. We always prioritize using custom all-reduce
            # kernel but fall back to PyTorch or pynccl if it is
            # disabled or not supported.
            pynccl_comm = self.pynccl_comm
            maybe_pynccl_context: Any
            if not pynccl_comm:
                maybe_pynccl_context = nullcontext()
            else:
                maybe_pynccl_context = pynccl_comm.change_state(
                    enable=True, stream=torch.cuda.current_stream())
            with maybe_pynccl_context:
                yield graph_capture_context

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
        User-facing all-reduce function before we actually call the
        all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. We need to pass the
        group name as a string, then look up the group coordinator from
        the group name and dispatch the all-reduce operation to the group
        coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
        a new tensor in the same op. So we need to figure out if the op is
        in-place or out-of-place ahead of time.
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

        if not supports_custom_op():
            return self._all_reduce(input_)

        if self.tpu_communicator is not None and \
            not self.tpu_communicator.disabled:
            # TPU handles Dynamo with its own logic.
            return self._all_reduce(input_)

        if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
            return torch.ops.aphrodite.outplace_all_reduce(
                input_, group_name=self.unique_name)
        else:
            torch.ops.aphrodite.inplace_all_reduce(
                input_, group_name=self.unique_name)
            return input_

    def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
        The actual all-reduce implementation.

        NOTE: This operation will be applied in-place or out-of-place.
        Always assume this function modifies its input, but use the return
        value as the output.
        """
        ca_comm = self.ca_comm

        # For TPUs, use TPU communicator.
        tpu_comm = self.tpu_communicator
        if tpu_comm is not None and not tpu_comm.disabled:
            return tpu_comm.all_reduce(input_)

        if ca_comm is not None:
            out = ca_comm.custom_all_reduce(input_)
            if out is not None:
                return out
        pynccl_comm = self.pynccl_comm
        if (pynccl_comm is not None and not pynccl_comm.disabled):
            pynccl_comm.all_reduce(input_)
        elif input_.is_cpu:
            import intel_extension_for_pytorch as ipex
            ipex.distributed.all_reduce(input_, group=self.device_group)
        else:
            torch.distributed.all_reduce(input_, group=self.device_group)
        return input_

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

        # For TPUs, use TPU communicator.
        tpu_comm = self.tpu_communicator
        if tpu_comm is not None and not tpu_comm.disabled:
            return tpu_comm.all_gather(input_, dim)

        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()
        # Allocate output tensor.
        output_tensor = torch.empty((world_size, ) + input_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        torch.distributed.all_gather_into_tensor(output_tensor,
                                                 input_,
                                                 group=self.device_group)
        # Reshape
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(input_size[:dim] +
                                              (world_size *
                                               input_size[dim], ) +
                                              input_size[dim + 1:])
        return output_tensor
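
    # Shape sketch (illustrative; `group` stands for any GroupCoordinator with
    # world_size == 4):
    #
    #     x = torch.randn(8, 16, device=group.device)  # per-rank shard
    #     y = group.all_gather(x, dim=-1)              # y.shape == (8, 64)
    #
    # The per-rank shards are concatenated along `dim`, so gathering along the
    # last dimension multiplies that dimension by world_size.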

    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> Optional[torch.Tensor]:
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        # Allocate output tensor.
        if self.rank_in_group == dst:
            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
        else:
            gather_list = None
        # Gather.
        torch.distributed.gather(input_,
                                 gather_list,
                                 dst=self.ranks[dst],
                                 group=self.device_group)
        if self.rank_in_group == dst:
            output_tensor = torch.cat(gather_list, dim=dim)
        else:
            output_tensor = None
        return output_tensor

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
        torch.distributed.broadcast(input_,
                                    src=self.ranks[src],
                                    group=self.device_group)
        return input_

    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
        """Broadcast the input object.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
        if self.rank_in_group == src:
            torch.distributed.broadcast_object_list([obj],
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return obj
        else:
            recv = [None]
            torch.distributed.broadcast_object_list(recv,
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return recv[0]

    def broadcast_object_list(self,
                              obj_list: List[Any],
                              src: int = 0,
                              group: Optional[ProcessGroup] = None):
        """Broadcast the input object list.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
        torch.distributed.broadcast_object_list(obj_list,
                                                src=self.ranks[src],
                                                group=self.device_group)
        return obj_list

    def send_object(self, obj: Any, dst: int) -> None:
        """Send the input object to the destination rank.
        NOTE: `dst` is the local rank of the destination rank.
        """
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        assert dst != self.rank_in_group, (
            "Invalid destination rank. Destination rank is the same "
            "as the current rank.")

        # Serialize object to tensor and get the size as well
        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

        size_tensor = torch.tensor([object_tensor.numel()],
                                   dtype=torch.long,
                                   device="cpu")

        # Send object size
        torch.distributed.send(size_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        # Send object
        torch.distributed.send(object_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        return None

    def recv_object(self, src: int) -> Any:
        """Receive an object from the source rank.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        assert src != self.rank_in_group, (
            "Invalid source rank. Source rank is the same as the current rank."
        )

        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

        # Receive object size
        rank_size = torch.distributed.recv(size_tensor,
                                           src=self.ranks[src],
                                           group=self.cpu_group)

        # Tensor to receive serialized objects into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
            device="cpu")

        rank_object = torch.distributed.recv(object_tensor,
                                             src=self.ranks[src],
                                             group=self.cpu_group)

        assert rank_object == rank_size, (
            "Received object sender rank does not match the size sender rank.")

        obj = pickle.loads(object_tensor.numpy().tobytes())

        return obj

    def broadcast_tensor_dict(
        self,
        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
        src: int = 0,
        group: Optional[ProcessGroup] = None,
        metadata_group: Optional[ProcessGroup] = None
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if (not torch.distributed.is_initialized() or self.world_size == 1):
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"

        rank_in_group = self.rank_in_group
        if rank_in_group == src:
            metadata_list: List[Tuple[Any, Any]] = []
            assert isinstance(
                tensor_dict,
                dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
            self.broadcast_object(metadata_list, src=src)
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    handle = torch.distributed.broadcast(tensor,
                                                         src=self.ranks[src],
                                                         group=metadata_group,
                                                         async_op=True)
                else:
                    # use group for GPU tensors
                    handle = torch.distributed.broadcast(tensor,
                                                         src=self.ranks[src],
                                                         group=group,
                                                         async_op=True)
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()
        else:
            metadata_list = self.broadcast_object(None, src=src)
            tensor_dict = {}
            async_handles = []
            for key, value in metadata_list:
                if isinstance(value, TensorMetadata):
                    tensor = torch.empty(value.size,
                                         dtype=value.dtype,
                                         device=value.device)
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
                        tensor_dict[key] = tensor
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=self.ranks[src],
                            group=metadata_group,
                            async_op=True)
                    else:
                        # use group for GPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=self.ranks[src],
                            group=group,
                            async_op=True)
                    async_handles.append(handle)
                    tensor_dict[key] = tensor
                else:
                    tensor_dict[key] = value
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict

    def send_tensor_dict(
        self,
        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
        dst: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Send the input tensor dictionary.
        NOTE: `dst` is the local rank of the destination rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict

        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        metadata_list: List[Tuple[Any, Any]] = []
        assert isinstance(
            tensor_dict,
            dict), f"Expecting a dictionary, got {type(tensor_dict)}"
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.send_object(metadata_list, dst=dst)
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip sending empty tensors.
                continue

            # send-allgather: send only a slice, then do allgather.
            if (all_gather_group is not None
                    and tensor.numel() % all_gather_size == 0):
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

            if tensor.is_cpu:
                # use metadata_group for CPU tensors
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=metadata_group)
            else:
                # use group for GPU tensors
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=group)
        return None

    def recv_tensor_dict(
        self,
        src: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Recv the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None

        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

        group = self.device_group
        metadata_group = self.cpu_group

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size, f"Invalid src rank ({src})"

        recv_metadata_list = self.recv_object(src=src)
        tensor_dict: Dict[str, Any] = {}
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
                                     device=value.device)
                if tensor.numel() == 0:
                    # Skip receiving empty tensors.
                    tensor_dict[key] = tensor
                    continue

                # send-allgather: send only a slice, then do allgather.
                use_all_gather = (all_gather_group is not None
                                  and tensor.numel() % all_gather_size == 0)

                if use_all_gather:
                    orig_shape = tensor.shape
                    tensor = tensor.reshape(all_gather_size,
                                            -1)[all_gather_rank]

                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    torch.distributed.recv(tensor,
                                           src=self.ranks[src],
                                           group=metadata_group)
                else:
                    # use group for GPU tensors
                    torch.distributed.recv(tensor,
                                           src=self.ranks[src],
                                           group=group)
                if use_all_gather:
                    # do the allgather
                    tensor = all_gather_group.all_gather(  # type: ignore
                        tensor, dim=0)
                    tensor = tensor.reshape(orig_shape)

                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        return tensor_dict
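
    # Pipeline-style usage sketch (illustrative; `pp` stands for a pipeline
    # GroupCoordinator and the keys/values are made up for the example):
    #
    #     if not pp.is_last_rank:
    #         pp.send_tensor_dict({"hidden": hidden_states, "step": 3})
    #     if not pp.is_first_rank:
    #         received = pp.recv_tensor_dict()  # same keys as on the sender
    #
    # With dst/src left as None, each rank sends to the next rank in the group
    # and receives from the previous one.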

    def barrier(self):
        """Barrier synchronization among the group.
        NOTE: don't use `device_group` here! `barrier` in NCCL is
        terrible because it is internally a broadcast operation with
        secretly created GPU tensors. It is easy to mess up the current
        device. Use the CPU group instead.
        """
        torch.distributed.barrier(group=self.cpu_group)

    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Sends a tensor to the destination rank in a non-blocking way.
        NOTE: `dst` is the local rank of the destination rank.
        """
        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size

        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.send(tensor, dst)
        else:
            torch.distributed.send(tensor, self.ranks[dst], self.device_group)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receives a tensor from the src rank.
        NOTE: `src` is the local rank of the source rank.
        """
        if src is None:
            src = (self.rank_in_group - 1) % self.world_size

        tensor = torch.empty(size, dtype=dtype, device=self.device)
        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.recv(tensor, src)
        else:
            torch.distributed.recv(tensor, self.ranks[src], self.device_group)
        return tensor

    def destroy(self):
        if self.device_group is not None:
            torch.distributed.destroy_process_group(self.device_group)
            self.device_group = None
        if self.cpu_group is not None:
            torch.distributed.destroy_process_group(self.cpu_group)
            self.cpu_group = None
        if self.pynccl_comm is not None:
            self.pynccl_comm = None
        if self.ca_comm is not None:
            self.ca_comm = None
        if self.mq_broadcaster is not None:
            self.mq_broadcaster = None


_WORLD: Optional[GroupCoordinator] = None


def get_world_group() -> GroupCoordinator:
    assert _WORLD is not None, ("world group is not initialized")
    return _WORLD


def init_world_group(ranks: List[int], local_rank: int,
                     backend: str) -> GroupCoordinator:
    return GroupCoordinator(
        group_ranks=[ranks],
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=False,
        use_custom_allreduce=False,
        use_tpu_communicator=False,
        group_name="world",
    )


def init_model_parallel_group(
    group_ranks: List[List[int]],
    local_rank: int,
    backend: str,
    use_custom_allreduce: Optional[bool] = None,
    use_message_queue_broadcaster: bool = False,
    group_name: Optional[str] = None,
) -> GroupCoordinator:
    if use_custom_allreduce is None:
        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
    return GroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=True,
        use_custom_allreduce=use_custom_allreduce,
        use_tpu_communicator=True,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
        group_name=group_name,
    )


_TP: Optional[GroupCoordinator] = None


def get_tp_group() -> GroupCoordinator:
    assert _TP is not None, ("tensor model parallel group is not initialized")
    return _TP


# kept for backward compatibility
get_tensor_model_parallel_group = get_tp_group

_PP: Optional[GroupCoordinator] = None


def get_pp_group() -> GroupCoordinator:
    assert _PP is not None, (
        "pipeline model parallel group is not initialized")
    return _PP


# kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group


@contextmanager
def graph_capture():
    """
    `graph_capture` is a context manager which should surround the code that
    is capturing the CUDA graph. Its main purpose is to ensure that some
    operations will be run after the graph is captured, before the graph
    is replayed. It returns a `GraphCaptureContext` object which contains the
    necessary data for the graph capture. Currently, it only contains the
    stream that the graph capture is running on. This stream is set to the
    current CUDA stream when the context manager is entered and reset to the
    default stream when the context manager is exited. This is to ensure that
    the graph capture is running on a separate stream from the default stream,
    in order to explicitly distinguish the kernels to capture from other
    kernels possibly launched in the background on the default stream.
    """
    with get_tp_group().graph_capture() as context, get_pp_group(
    ).graph_capture(context):
        yield context
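
# Usage sketch (illustrative; assumes the model parallel groups are already
# initialized, and `run_model`/`static_inputs` stand in for user code):
#
#     graph = torch.cuda.CUDAGraph()
#     with graph_capture() as capture_ctx:
#         with torch.cuda.graph(graph, stream=capture_ctx.stream):
#             static_outputs = run_model(static_inputs)
#     graph.replay()
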
_ENABLE_CUSTOM_ALL_REDUCE = True


def set_custom_all_reduce(enable: bool):
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable


def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    logger.debug(
        f"world_size={world_size} rank={rank} local_rank={local_rank} "
        f"distributed_init_method={distributed_init_method} backend={backend}")
    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment")
        if sys.platform.startswith("win32") and \
                distributed_init_method.startswith("tcp://"):
            distributed_init_method += "?use_libuv=0"
            backend = "gloo"
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank)
    # set the local rank
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # local rank not set, this usually happens in single-node
        # setting, where we can use rank as local rank
        if distributed_init_method == "env://":
            local_rank = envs.LOCAL_RANK
        else:
            local_rank = rank
    global _WORLD
    if _WORLD is None:
        ranks = list(range(torch.distributed.get_world_size()))
        _WORLD = init_world_group(ranks, local_rank, backend)
    else:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size")


def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    if (world_size !=
            tensor_model_parallel_size * pipeline_model_parallel_size):
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")

    # Build the tensor model-parallel groups.
    num_tensor_model_parallel_groups: int = (world_size //
                                             tensor_model_parallel_size)
    global _TP
    assert _TP is None, ("tensor model parallel group is already initialized")
    group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        ranks = list(
            range(i * tensor_model_parallel_size,
                  (i + 1) * tensor_model_parallel_size))
        group_ranks.append(ranks)

    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_message_queue_broadcaster=True,
                                    group_name="tp")

    # Build the pipeline model-parallel groups.
    num_pipeline_model_parallel_groups: int = (world_size //
                                               pipeline_model_parallel_size)
    global _PP
    assert _PP is None, (
        "pipeline model parallel group is already initialized")
    group_ranks = []
    for i in range(num_pipeline_model_parallel_groups):
        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
        group_ranks.append(ranks)
    # pipeline parallel does not need custom allreduce
    _PP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_custom_allreduce=False,
                                    group_name="pp")
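
# Group-layout sketch (illustrative; reproduces the 8-GPU example from the
# docstring above with tensor_model_parallel_size=2 and
# pipeline_model_parallel_size=4):
#
#     world_size = 8
#     tp_groups = [list(range(i * 2, (i + 1) * 2)) for i in range(8 // 2)]
#     # -> [[0, 1], [2, 3], [4, 5], [6, 7]]
#     pp_groups = [list(range(i, 8, 8 // 4)) for i in range(8 // 4)]
#     # -> [[0, 2, 4, 6], [1, 3, 5, 7]]
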
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Helper to initialize model parallel groups if they are not initialized,
    or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
    values if the model parallel groups are initialized.
    """
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)
    if not model_parallel_is_initialized():
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size, backend)
        return

    assert (
        get_tensor_model_parallel_world_size() == tensor_model_parallel_size
    ), ("tensor parallel group already initialized, but of unexpected size: "
        f"{get_tensor_model_parallel_world_size()=} vs. "
        f"{tensor_model_parallel_size=}")
    pp_world_size = get_pp_group().world_size
    assert (pp_world_size == pipeline_model_parallel_size), (
        "pipeline parallel group already initialized, but of unexpected size: "
        f"{pp_world_size=} vs. "
        f"{pipeline_model_parallel_size=}")


def model_parallel_is_initialized():
    """Check if tensor and pipeline parallel groups are initialized."""
    return (_TP is not None and _PP is not None)


_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run the draft
    model with a different tp degree from that of the target model workers.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator
    """
    global _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    _TP_STATE_PATCHED = True
    old_tp_group = get_tp_group()
    global _TP
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group


def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group."""
    return get_tp_group().world_size


def get_tensor_model_parallel_rank():
    """Return my rank for the tensor model parallel group."""
    return get_tp_group().rank_in_group


def destroy_model_parallel():
    """Set the groups to none and destroy them."""
    global _TP
    if _TP:
        _TP.destroy()
    _TP = None

    global _PP
    if _PP:
        _PP.destroy()
    _PP = None


def destroy_distributed_environment():
    global _WORLD
    if _WORLD:
        _WORLD.destroy()
    _WORLD = None
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()


def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
    """
    This is a collective operation that returns whether each rank is in the
    same node as the source rank. It tests whether processes are attached to
    the same memory system (shared access to shared memory).
    """
    assert torch.distributed.get_backend(
        pg) != torch.distributed.Backend.NCCL, (
            "in_the_same_node_as should be tested with a non-NCCL group.")
    # local rank inside the group
    rank = torch.distributed.get_rank(group=pg)
    world_size = torch.distributed.get_world_size(group=pg)

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

    # global ranks of the processes in the group
    ranks = torch.distributed.get_process_group_ranks(pg)

    magic_message = b"magic_message"
    shm = None

    try:
        with contextlib.suppress(OSError):
            if rank == source_rank:
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[:len(magic_message)] = magic_message
                torch.distributed.broadcast_object_list([shm.name],
                                                        src=ranks[source_rank],
                                                        group=pg)
                is_in_the_same_node[rank] = 1
            else:
                # try to open the shared memory segment
                recv = [None]
                torch.distributed.broadcast_object_list(recv,
                                                        src=ranks[source_rank],
                                                        group=pg)
                name = recv[0]
                # fix to https://stackoverflow.com/q/62748654/9191338
                # Python incorrectly tracks shared memory even if it is not
                # created by the process. The following patch is a workaround.
                with patch("multiprocessing.resource_tracker.register",
                           lambda *args, **kwargs: None):
                    shm = shared_memory.SharedMemory(name=name)
                if shm.buf[:len(magic_message)] == magic_message:
                    is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error(f"Error ignored in is_in_the_same_node: {e}")
    finally:
        if shm:
            shm.close()
        torch.distributed.barrier(group=pg)
        # clean up the shared memory segment
        with contextlib.suppress(OSError):
            if rank == source_rank and shm:
                shm.unlink()
    torch.distributed.all_reduce(is_in_the_same_node, group=pg)

    return [x == 1 for x in is_in_the_same_node.tolist()]


def get_current_tp_rank_partition_offset(total_size: int,
                                         tp_rank: Optional[int] = None,
                                         tp_size: Optional[int] = None,
                                         multiple_of: int = 1) -> int:
    if tp_rank is None:
        tp_rank = get_tensor_model_parallel_rank()

    if tp_size is None:
        tp_size = get_tensor_model_parallel_world_size()

    assert total_size % multiple_of == 0
    total_size = total_size // multiple_of
    return ((total_size // tp_size) * tp_rank +
            min(total_size % tp_size, tp_rank)) * multiple_of


def get_current_tp_rank_partition_size(total_size: int,
                                       tp_rank: Optional[int] = None,
                                       tp_size: Optional[int] = None,
                                       multiple_of: int = 1) -> int:
    if tp_rank is None:
        tp_rank = get_tensor_model_parallel_rank()

    if tp_size is None:
        tp_size = get_tensor_model_parallel_world_size()

    assert total_size % multiple_of == 0
    total_size = total_size // multiple_of
    return ((total_size // tp_size) +
            (total_size % tp_size > tp_rank)) * multiple_of
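

# Worked example (illustrative): partitioning total_size=10 with multiple_of=2
# across tp_size=4 ranks. After dividing out multiple_of there are 5 units,
# split as 2, 1, 1, 1, so the (offset, size) pairs in original elements are
# (0, 4), (4, 2), (6, 2), (8, 2):
#
#     for r in range(4):
#         off = get_current_tp_rank_partition_offset(10, tp_rank=r, tp_size=4,
#                                                    multiple_of=2)
#         size = get_current_tp_rank_partition_size(10, tp_rank=r, tp_size=4,
#                                                   multiple_of=2)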