parallel_state.py

# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
  6. """Aphrodite distributed state.
  7. It takes over the control of the distributed environment from PyTorch.
  8. The typical workflow is:
  9. - call `init_distributed_environment` to initialize the distributed environment.
  10. - call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
  11. initialize the model parallel groups.
  12. - any code dealing with the distributed stuff
  13. - call `destroy_model_parallel` to destroy the model parallel groups.
  14. - call `destroy_distributed_environment` to destroy the distributed environment.
  15. If you only need to use the distributed environment without model/pipeline
  16. parallelism, you can skip the model parallel initialization and destruction
  17. steps.
  18. """
import contextlib
import os
import pickle
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import Any, Dict, List, Optional, Tuple, Union
from unittest.mock import patch

import torch
import torch.distributed
from loguru import logger
from torch.distributed import Backend, ProcessGroup

@dataclass
class GraphCaptureContext:
    stream: torch.cuda.Stream


TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])

def _split_tensor_dict(
    tensor_dict: Dict[str, Union[torch.Tensor, Any]],
    prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
       by its metadata.
    2. A list of tensors.

    If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its
    metadata will be "key1%key2".
    """
    metadata_list: List[Tuple[str, Any]] = []
    tensor_list = []
    for key, value in tensor_dict.items():
        assert "%" not in key, (
            "Avoid having '%' in key "
            "as it is used as a separator for nested entries.")
        if isinstance(value, torch.Tensor):
            # Note: we cannot use `value.device` here,
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type.
            # The receiving side will set the device index.
            device = value.device.type
            metadata_list.append(
                (prefix + key, TensorMetadata(device, value.dtype,
                                              value.size())))
            tensor_list.append(value)
        elif isinstance(value, dict):
            if len(value) == 0:
                metadata_list.append((prefix + key, value))
            inner_metadata_list, inner_tensor_list = _split_tensor_dict(
                value, prefix + key + "%")
            metadata_list.extend(inner_metadata_list)
            tensor_list.extend(inner_tensor_list)
        else:
            metadata_list.append((prefix + key, value))
    return metadata_list, tensor_list

def _update_nested_dict(nested_dict, flattened_key, value):
    key_splits = flattened_key.split("%")
    cur_dict = nested_dict
    for k in key_splits[:-1]:
        if k not in cur_dict:
            cur_dict[k] = {}
        cur_dict = cur_dict[k]
    cur_dict[key_splits[-1]] = value
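
# A small round-trip sketch of the two helpers above (hypothetical values):
#
#     metadata, tensors = _split_tensor_dict({"a": {"b": torch.zeros(2)},
#                                             "c": 1})
#     # metadata == [("a%b", TensorMetadata("cpu", torch.float32,
#     #                                     torch.Size([2]))), ("c", 1)]
#     # tensors  == [tensor([0., 0.])]
#     rebuilt: dict = {}
#     for key, value in metadata:
#         _update_nested_dict(rebuilt, key, value)
#     # rebuilt == {"a": {"b": TensorMetadata(...)}, "c": 1}; the receiving
#     # side allocates real tensors from the metadata before filling them in.
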
class GroupCoordinator:
    """
    PyTorch ProcessGroup wrapper for a group of processes.
    PyTorch ProcessGroup is bound to one specific communication backend,
    e.g. NCCL, Gloo, MPI, etc.
    GroupCoordinator takes charge of all the communication operations among
    the processes in the group. It can route the communication to
    a specific implementation (e.g. switch allreduce implementation
    based on the tensor size and cuda graph mode).
    """

    # available attributes:
    rank: int  # global rank
    ranks: List[int]  # global ranks in the group
    world_size: int  # size of the group
    # difference between `local_rank` and `rank_in_group`:
    # if we have a group of size 4 across two nodes:
    # Process | Node | Rank | Local Rank | Rank in Group
    #    0    |   0  |   0  |     0      |       0
    #    1    |   0  |   1  |     1      |       1
    #    2    |   1  |   2  |     0      |       2
    #    3    |   1  |   3  |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    use_pynccl: bool  # a hint of whether to use PyNccl
    use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
    # communicators are only created for world size > 1
    pynccl_comm: Optional[Any]  # PyNccl communicator
    ca_comm: Optional[Any]  # Custom allreduce communicator
    mq_broadcaster: Optional[Any]  # shared memory broadcaster
    def __init__(
        self,
        group_ranks: List[List[int]],
        local_rank: int,
        torch_distributed_backend: Union[str, Backend],
        use_pynccl: bool,
        use_custom_allreduce: bool,
        use_message_queue_broadcaster: bool = False,
    ):
        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank
        self.device_group = None
        self.cpu_group = None

        for ranks in group_ranks:
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend)
            # a group with `gloo` backend, to allow direct coordination between
            # processes through the CPU.
            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
                self.device_group = device_group
                self.cpu_group = cpu_group

        assert self.cpu_group is not None
        assert self.device_group is not None

        if torch.cuda.is_available():
            self.device = torch.device(f"cuda:{local_rank}")
        else:
            self.device = torch.device("cpu")

        self.use_pynccl = use_pynccl
        self.use_custom_allreduce = use_custom_allreduce

        # lazy import to avoid documentation build error
        from aphrodite.distributed.device_communicators.custom_all_reduce import \
            CustomAllreduce  # noqa: E501
        from aphrodite.distributed.device_communicators.pynccl import \
            PyNcclCommunicator

        self.pynccl_comm: Optional[PyNcclCommunicator]
        if use_pynccl and self.world_size > 1:
            self.pynccl_comm = PyNcclCommunicator(
                group=self.cpu_group,
                device=self.device,
            )
        else:
            self.pynccl_comm = None

        self.ca_comm: Optional[CustomAllreduce]
        if use_custom_allreduce and self.world_size > 1:
            # Initialize a custom fast all-reduce implementation.
            self.ca_comm = CustomAllreduce(
                group=self.cpu_group,
                device=self.device,
            )
        else:
            self.ca_comm = None

        from aphrodite.distributed.device_communicators.shm_broadcast import (
            MessageQueue)
        self.mq_broadcaster: Optional[MessageQueue] = None
        if use_message_queue_broadcaster and self.world_size > 1:
            self.mq_broadcaster = MessageQueue.create_from_process_group(
                self.cpu_group, 1 << 22, 6)
    @property
    def first_rank(self):
        """Return the global rank of the first process in the group"""
        return self.ranks[0]

    @property
    def last_rank(self):
        """Return the global rank of the last process in the group"""
        return self.ranks[-1]

    @property
    def is_first_rank(self):
        """Return whether the caller is the first process in the group"""
        return self.rank == self.first_rank

    @property
    def is_last_rank(self):
        """Return whether the caller is the last process in the group"""
        return self.rank == self.last_rank

    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group - 1) % world_size]
    @contextmanager
    def graph_capture(
            self, graph_capture_context: Optional[GraphCaptureContext] = None):
        if graph_capture_context is None:
            stream = torch.cuda.Stream()
            graph_capture_context = GraphCaptureContext(stream)
        else:
            stream = graph_capture_context.stream

        ca_comm = self.ca_comm
        maybe_ca_context = nullcontext(
        ) if ca_comm is None else ca_comm.capture()
        with torch.cuda.stream(stream), maybe_ca_context:
            # In graph mode, we have to be very careful about the collective
            # operations. The current status is:
            #     allreduce \ Mode  |  Eager  |  Graph   |
            # --------------------------------------------
            # custom allreduce      | enabled | enabled  |
            # PyNccl                | disabled| enabled  |
            # torch.distributed     | enabled | disabled |
            #
            # Note that custom allreduce will have a runtime check, if the
            # tensor size is too large, it will fallback to the next
            # available option.
            # In summary: When using CUDA graph, we use
            # either custom all-reduce kernel or pynccl. When not using
            # CUDA graph, we use either custom all-reduce kernel or
            # PyTorch NCCL. We always prioritize using custom all-reduce
            # kernel but fall back to PyTorch or pynccl if it is
            # disabled or not supported.
            pynccl_comm = self.pynccl_comm
            maybe_pynccl_context: Any
            if not pynccl_comm:
                maybe_pynccl_context = nullcontext()
            else:
                maybe_pynccl_context = pynccl_comm.change_state(
                    enable=True, stream=torch.cuda.current_stream())
            with maybe_pynccl_context:
                yield graph_capture_context
    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
        NOTE: This operation will be applied in-place or out-of-place.
        Always assume this function modifies its input, but use the return
        value as the output.
        """
        ca_comm = self.ca_comm

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        if ca_comm is not None:
            out = ca_comm.custom_all_reduce(input_)
            if out is not None:
                return out
        pynccl_comm = self.pynccl_comm
        if (pynccl_comm is not None and not pynccl_comm.disabled):
            pynccl_comm.all_reduce(input_)
        else:
            torch.distributed.all_reduce(input_, group=self.device_group)
        return input_
    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()
        # Allocate output tensor.
        output_tensor = torch.empty((world_size, ) + input_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        torch.distributed.all_gather_into_tensor(output_tensor,
                                                 input_,
                                                 group=self.device_group)
        # Reshape
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(input_size[:dim] +
                                              (world_size *
                                               input_size[dim], ) +
                                              input_size[dim + 1:])
        return output_tensor
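
    # Shape sketch for `all_gather` (illustrative): with world_size == 4 and a
    # per-rank input of shape (8, 16), `all_gather(input_, dim=-1)` returns a
    # tensor of shape (8, 64), while `all_gather(input_, dim=0)` returns
    # (32, 16).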
    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> torch.Tensor:
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        # Allocate output tensor.
        if self.rank_in_group == dst:
            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
        else:
            gather_list = None
        # Gather.
        torch.distributed.gather(input_,
                                 gather_list,
                                 dst=self.ranks[dst],
                                 group=self.device_group)
        if self.rank_in_group == dst:
            output_tensor = torch.cat(gather_list, dim=dim)
        else:
            output_tensor = None
        return output_tensor
    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
        torch.distributed.broadcast(input_,
                                    src=self.ranks[src],
                                    group=self.device_group)
        return input_
    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
        """Broadcast the input object.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
        if self.rank_in_group == src:
            torch.distributed.broadcast_object_list([obj],
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return obj
        else:
            recv = [None]
            torch.distributed.broadcast_object_list(recv,
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return recv[0]
    def broadcast_object_list(self,
                              obj_list: List[Any],
                              src: int = 0,
                              group: Optional[ProcessGroup] = None):
        """Broadcast the input object list.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
        torch.distributed.broadcast_object_list(obj_list,
                                                src=self.ranks[src],
                                                group=self.device_group)
        return obj_list
    def send_object(self, obj: Any, dst: int) -> None:
        """Send the input object to the destination rank.
        NOTE: `dst` is the local rank of the destination rank.
        """
        assert dst < self.world_size, f"Invalid dst rank ({dst})"
        assert dst != self.rank_in_group, (
            "Invalid destination rank. Destination rank is the same "
            "as the current rank.")

        # Serialize object to tensor and get the size as well
        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

        size_tensor = torch.tensor([object_tensor.numel()],
                                   dtype=torch.long,
                                   device="cpu")

        # Send object size
        torch.distributed.send(size_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        # Send object
        torch.distributed.send(object_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        return None
    def recv_object(self, src: int) -> Any:
        """Receive an object from the source rank.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"
        assert src != self.rank_in_group, (
            "Invalid source rank. Source rank is the same as the current rank."
        )

        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

        # Receive object size
        rank_size = torch.distributed.recv(size_tensor,
                                           src=self.ranks[src],
                                           group=self.cpu_group)

        # Tensor to receive serialized objects into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
            device="cpu")

        rank_object = torch.distributed.recv(object_tensor,
                                             src=self.ranks[src],
                                             group=self.cpu_group)

        assert rank_object == rank_size, (
            "Received object sender rank does not match the size sender rank.")

        obj = pickle.loads(object_tensor.numpy().tobytes())

        return obj
    def broadcast_tensor_dict(
        self,
        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
        src: int = 0,
        group: Optional[ProcessGroup] = None,
        metadata_group: Optional[ProcessGroup] = None
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if (not torch.distributed.is_initialized() or self.world_size == 1):
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"

        rank_in_group = self.rank_in_group
        if rank_in_group == src:
            metadata_list: List[Tuple[Any, Any]] = []
            assert isinstance(
                tensor_dict,
                dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
            self.broadcast_object(metadata_list, src=src)
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    handle = torch.distributed.broadcast(tensor,
                                                         src=self.ranks[src],
                                                         group=metadata_group,
                                                         async_op=True)
                else:
                    # use group for GPU tensors
                    handle = torch.distributed.broadcast(tensor,
                                                         src=self.ranks[src],
                                                         group=group,
                                                         async_op=True)
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()
        else:
            metadata_list = self.broadcast_object(None, src=src)
            tensor_dict = {}
            async_handles = []
            for key, value in metadata_list:
                if isinstance(value, TensorMetadata):
                    tensor = torch.empty(value.size,
                                         dtype=value.dtype,
                                         device=value.device)
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
                        _update_nested_dict(tensor_dict, key, tensor)
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=self.ranks[src],
                            group=metadata_group,
                            async_op=True)
                    else:
                        # use group for GPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=self.ranks[src],
                            group=group,
                            async_op=True)
                    async_handles.append(handle)
                    _update_nested_dict(tensor_dict, key, tensor)
                else:
                    _update_nested_dict(tensor_dict, key, value)
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict
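
    # Illustrative call pattern for `broadcast_tensor_dict` (hypothetical
    # tensors; every rank in the group must participate):
    #
    #     if group.rank_in_group == 0:
    #         out = group.broadcast_tensor_dict(
    #             {"x": torch.ones(2, device="cuda"), "meta": {"step": 3}})
    #     else:
    #         out = group.broadcast_tensor_dict()
    #     # afterwards every rank holds out["x"] and out["meta"]["step"] == 3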
    def send_tensor_dict(
        self,
        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
        dst: Optional[int] = None
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Send the input tensor dictionary.
        NOTE: `dst` is the local rank of the destination rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        metadata_list: List[Tuple[Any, Any]] = []
        assert isinstance(
            tensor_dict,
            dict), f"Expecting a dictionary, got {type(tensor_dict)}"
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.send_object(metadata_list, dst=dst)
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip sending empty tensors.
                continue
            if tensor.is_cpu:
                # use metadata_group for CPU tensors
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=metadata_group)
            else:
                # use group for GPU tensors
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=group)
        return None
    def recv_tensor_dict(
        self,
        src: Optional[int] = None
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Recv the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None

        group = self.device_group
        metadata_group = self.cpu_group

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size, f"Invalid src rank ({src})"

        recv_metadata_list = self.recv_object(src=src)
        tensor_dict: Dict[str, Any] = {}
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
                                     device=value.device)
                if tensor.numel() == 0:
                    # Skip receiving empty tensors.
                    _update_nested_dict(tensor_dict, key, tensor)
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    torch.distributed.recv(tensor,
                                           src=self.ranks[src],
                                           group=metadata_group)
                else:
                    # use group for GPU tensors
                    torch.distributed.recv(tensor,
                                           src=self.ranks[src],
                                           group=group)
                _update_nested_dict(tensor_dict, key, tensor)
            else:
                _update_nested_dict(tensor_dict, key, value)
        return tensor_dict
    def barrier(self):
        """Barrier synchronization among the group.
        NOTE: don't use `device_group` here! `barrier` in NCCL is
        terrible because it is internally a broadcast operation with
        secretly created GPU tensors. It is easy to mess up the current
        device. Use the CPU group instead.
        """
        torch.distributed.barrier(group=self.cpu_group)
    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Sends a tensor to the destination rank in a non-blocking way.
        NOTE: `dst` is the local rank of the destination rank.
        """
        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size

        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.send(tensor, dst)
        else:
            torch.distributed.send(tensor, self.ranks[dst], self.device_group)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receives a tensor from the source rank.
        NOTE: `src` is the local rank of the source rank.
        """
        if src is None:
            src = (self.rank_in_group - 1) % self.world_size

        tensor = torch.empty(size, dtype=dtype, device=self.device)
        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.recv(tensor, src)
        else:
            torch.distributed.recv(tensor, self.ranks[src], self.device_group)
        return tensor
    def destroy(self):
        if self.device_group is not None:
            torch.distributed.destroy_process_group(self.device_group)
            self.device_group = None
        if self.cpu_group is not None:
            torch.distributed.destroy_process_group(self.cpu_group)
            self.cpu_group = None
        if self.pynccl_comm is not None:
            self.pynccl_comm = None
        if self.ca_comm is not None:
            self.ca_comm = None
        if self.mq_broadcaster is not None:
            self.mq_broadcaster = None

_WORLD: Optional[GroupCoordinator] = None


def get_world_group() -> GroupCoordinator:
    assert _WORLD is not None, ("world group is not initialized")
    return _WORLD


def init_world_group(ranks: List[int], local_rank: int,
                     backend: str) -> GroupCoordinator:
    return GroupCoordinator(
        group_ranks=[ranks],
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=False,
        use_custom_allreduce=False,
    )


def init_model_parallel_group(
    group_ranks: List[List[int]],
    local_rank: int,
    backend: str,
    use_custom_allreduce: Optional[bool] = None,
    use_message_queue_broadcaster: bool = False,
) -> GroupCoordinator:
    if use_custom_allreduce is None:
        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
    return GroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=True,
        use_custom_allreduce=use_custom_allreduce,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
    )


_TP: Optional[GroupCoordinator] = None


def get_tp_group() -> GroupCoordinator:
    assert _TP is not None, ("tensor model parallel group is not initialized")
    return _TP


# kept for backward compatibility
get_tensor_model_parallel_group = get_tp_group

_PP: Optional[GroupCoordinator] = None


def get_pp_group() -> GroupCoordinator:
    assert _PP is not None, (
        "pipeline model parallel group is not initialized")
    return _PP


# kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group

@contextmanager
def graph_capture():
    """
    `graph_capture` is a context manager which should surround the code that
    is capturing the CUDA graph. Its main purpose is to ensure that some
    operations will be run after the graph is captured, before the graph
    is replayed. It returns a `GraphCaptureContext` object which contains the
    necessary data for the graph capture. Currently, it only contains the
    stream that the graph capture is running on. This stream is set to the
    current CUDA stream when the context manager is entered and reset to the
    default stream when the context manager is exited. This is to ensure that
    the graph capture is running on a separate stream from the default stream,
    in order to explicitly distinguish the kernels to capture from other
    kernels possibly launched in the background on the default stream.
    """
    with get_tp_group().graph_capture() as context, get_pp_group(
    ).graph_capture(context):
        yield context
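
# Illustrative capture sketch (assumes the parallel groups are initialized and
# `runner`/`static_input` are hypothetical placeholders for the callable whose
# kernels should be captured and its pre-allocated input):
#
#     cuda_graph = torch.cuda.CUDAGraph()
#     with graph_capture() as ctx:
#         with torch.cuda.graph(cuda_graph, stream=ctx.stream):
#             output = runner(static_input)
#     cuda_graph.replay()
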
_ENABLE_CUSTOM_ALL_REDUCE = True


def set_custom_all_reduce(enable: bool):
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable

def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    logger.debug(
        f"world_size={world_size} rank={rank} local_rank={local_rank} "
        f"distributed_init_method={distributed_init_method} backend={backend}")
    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment")
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank)
    # set the local rank
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # local rank not set, this usually happens in single-node
        # setting, where we can use rank as local rank
        if distributed_init_method == "env://":
            # launchers such as torchrun export LOCAL_RANK as a string,
            # so cast it; fall back to the global rank if it is not set.
            local_rank = int(os.getenv("LOCAL_RANK", rank))
        else:
            local_rank = rank
    global _WORLD
    if _WORLD is None:
        ranks = list(range(torch.distributed.get_world_size()))
        _WORLD = init_world_group(ranks, local_rank, backend)
    else:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size")

def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    if (world_size !=
            tensor_model_parallel_size * pipeline_model_parallel_size):
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")

    # Build the tensor model-parallel groups.
    num_tensor_model_parallel_groups: int = (world_size //
                                             tensor_model_parallel_size)
    global _TP
    assert _TP is None, ("tensor model parallel group is already initialized")
    group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        ranks = list(
            range(i * tensor_model_parallel_size,
                  (i + 1) * tensor_model_parallel_size))
        group_ranks.append(ranks)

    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_message_queue_broadcaster=True)

    # Build the pipeline model-parallel groups.
    num_pipeline_model_parallel_groups: int = (world_size //
                                               pipeline_model_parallel_size)
    global _PP
    assert _PP is None, (
        "pipeline model parallel group is already initialized")
    group_ranks = []
    for i in range(num_pipeline_model_parallel_groups):
        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
        group_ranks.append(ranks)
    # pipeline parallel does not need custom allreduce
    _PP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_custom_allreduce=False)
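
# Group layout sketch (illustrative): with world_size == 4,
# tensor_model_parallel_size == 2 and pipeline_model_parallel_size == 2, the
# loops above produce
#     TP group_ranks == [[0, 1], [2, 3]]
#     PP group_ranks == [[0, 2], [1, 3]]
# so e.g. global rank 3 belongs to TP group [2, 3] and PP group [1, 3].
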
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Helper to initialize model parallel groups if they are not initialized,
    or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
    values if the model parallel groups are initialized.
    """
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)
    if not model_parallel_is_initialized():
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size, backend)
        return

    assert (
        get_tensor_model_parallel_world_size() == tensor_model_parallel_size
    ), ("tensor parallel group already initialized, but of unexpected size: "
        f"{get_tensor_model_parallel_world_size()=} vs. "
        f"{tensor_model_parallel_size=}")
    pp_world_size = get_pp_group().world_size
    assert (pp_world_size == pipeline_model_parallel_size), (
        "pipeline parallel group already initialized, but of unexpected size: "
        f"{pp_world_size=} vs. "
        f"{pipeline_model_parallel_size=}")


def model_parallel_is_initialized():
    """Check if tensor and pipeline parallel groups are initialized."""
    return (_TP is not None and _PP is not None)

_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run the draft
    model with a different tp degree from that of the target model workers.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator
    """
    global _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    _TP_STATE_PATCHED = True
    old_tp_group = get_tp_group()
    global _TP
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group


def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group."""
    return get_tp_group().world_size


def get_tensor_model_parallel_rank():
    """Return my rank for the tensor model parallel group."""
    return get_tp_group().rank_in_group

def destroy_model_parallel():
    """Set the groups to none and destroy them."""
    global _TP
    if _TP:
        _TP.destroy()
    _TP = None

    global _PP
    if _PP:
        _PP.destroy()
    _PP = None


def destroy_distributed_environment():
    global _WORLD
    if _WORLD:
        _WORLD.destroy()
    _WORLD = None
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()

def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
    """
    This is a collective operation that returns whether each rank is in the
    same node as the source rank. It tests if processes are attached to the
    same memory system (shared access to shared memory).
    """
    assert torch.distributed.get_backend(
        pg) != torch.distributed.Backend.NCCL, (
            "in_the_same_node_as should be tested with a non-NCCL group.")
    # local rank inside the group
    rank = torch.distributed.get_rank(group=pg)
    world_size = torch.distributed.get_world_size(group=pg)

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

    # global ranks of the processes in the group
    ranks = torch.distributed.get_process_group_ranks(pg)

    magic_message = b"magic_message"
    shm = None

    try:
        with contextlib.suppress(OSError):
            if rank == source_rank:
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[:len(magic_message)] = magic_message
                torch.distributed.broadcast_object_list([shm.name],
                                                        src=ranks[source_rank],
                                                        group=pg)
                is_in_the_same_node[rank] = 1
            else:
                # try to open the shared memory segment
                recv = [None]
                torch.distributed.broadcast_object_list(recv,
                                                        src=ranks[source_rank],
                                                        group=pg)
                name = recv[0]
                # fix to https://stackoverflow.com/q/62748654/9191338
                # Python incorrectly tracks shared memory even if it is not
                # created by the process. The following patch is a workaround.
                with patch("multiprocessing.resource_tracker.register",
                           lambda *args, **kwargs: None):
                    shm = shared_memory.SharedMemory(name=name)
                if shm.buf[:len(magic_message)] == magic_message:
                    is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error(f"Error ignored in is_in_the_same_node: {e}")
    finally:
        if shm:
            shm.close()
        torch.distributed.barrier(group=pg)

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm:
            shm.unlink()
    torch.distributed.all_reduce(is_in_the_same_node, group=pg)

    return [x == 1 for x in is_in_the_same_node.tolist()]

def get_current_tp_rank_partition_offset(total_size: int,
                                         tp_rank: Optional[int] = None,
                                         tp_size: Optional[int] = None,
                                         multiple_of: int = 1) -> int:
    if tp_rank is None:
        tp_rank = get_tensor_model_parallel_rank()

    if tp_size is None:
        tp_size = get_tensor_model_parallel_world_size()

    assert total_size % multiple_of == 0
    total_size = total_size // multiple_of
    return ((total_size // tp_size) * tp_rank +
            min(total_size % tp_size, tp_rank)) * multiple_of


def get_current_tp_rank_partition_size(total_size: int,
                                       tp_rank: Optional[int] = None,
                                       tp_size: Optional[int] = None,
                                       multiple_of: int = 1) -> int:
    if tp_rank is None:
        tp_rank = get_tensor_model_parallel_rank()

    if tp_size is None:
        tp_size = get_tensor_model_parallel_world_size()

    assert total_size % multiple_of == 0
    total_size = total_size // multiple_of
    return ((total_size // tp_size) +
            (total_size % tp_size > tp_rank)) * multiple_of
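
# Worked example (illustrative): with total_size == 10, tp_size == 4 and
# multiple_of == 1, get_current_tp_rank_partition_size returns 3, 3, 2, 2 for
# tp_rank 0..3, and get_current_tp_rank_partition_offset returns 0, 3, 6, 8,
# i.e. the first `total_size % tp_size` ranks receive one extra element each.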