communication_op.py

from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch.distributed import ProcessGroup

from .parallel_state import (get_cpu_world_group, get_pp_pynccl_communicator,
                             get_tensor_model_parallel_group,
                             get_tensor_model_parallel_rank,
                             get_tensor_model_parallel_world_size,
                             get_tp_ca_communicator,
                             get_tp_pynccl_communicator)


@dataclass
class GraphCaptureContext:
    stream: torch.cuda.Stream


@contextmanager
def graph_capture():
    """
    `graph_capture` is a context manager which should surround the code that
    is capturing the CUDA graph. Its main purpose is to ensure that some
    operations will be run after the graph is captured, before the graph is
    replayed. It returns a `GraphCaptureContext` object which contains the
    necessary data for the graph capture. Currently, it only contains the
    stream that the graph capture is running on. This stream is set to the
    current CUDA stream when the context manager is entered and reset to the
    default stream when the context manager is exited. This is to ensure that
    the graph capture is running on a separate stream from the default
    stream, in order to explicitly distinguish the kernels to capture from
    other kernels possibly launched in the background in the default stream.
    """
    stream = torch.cuda.Stream()
    graph_capture_context = GraphCaptureContext(stream)
    ca_comm = get_tp_ca_communicator()
    maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()
    with torch.cuda.stream(stream), maybe_ca_context:
        # In graph mode, we have to be very careful about the collective
        # operations. The current status is:
        #     allreduce \ Mode   |  Eager   |  Graph   |
        # ---------------------------------------------
        # custom allreduce       | enabled  | enabled  |
        # PyNccl                 | disabled | enabled  |
        # torch.distributed      | enabled  | disabled |
        #
        # Note that custom allreduce has a runtime check: if the tensor size
        # is too large, it falls back to the next available option.
        # In summary: when using CUDA graph, we use either the custom
        # all-reduce kernel or pynccl; when not using CUDA graph, we use
        # either the custom all-reduce kernel or PyTorch NCCL. We always
        # prioritize the custom all-reduce kernel, but fall back to pynccl
        # or PyTorch NCCL if it is disabled or not supported.
        tp_pynccl_comm = get_tp_pynccl_communicator()
        pp_pynccl_comm = get_pp_pynccl_communicator()
        if not tp_pynccl_comm:
            maybe_tp_pynccl_context = nullcontext()
        else:
            maybe_tp_pynccl_context = tp_pynccl_comm.change_state(
                enable=True, stream=torch.cuda.current_stream())
        if not pp_pynccl_comm:
            maybe_pp_pynccl_context = nullcontext()
        else:
            maybe_pp_pynccl_context = pp_pynccl_comm.change_state(
                enable=True, stream=torch.cuda.current_stream())
        with maybe_tp_pynccl_context, maybe_pp_pynccl_context:
            yield graph_capture_context
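

# NOTE: The sketch below is illustrative only and is not part of the original
# module. It assumes a CUDA device and an initialized tensor-parallel group,
# and shows the intended pattern: wrap `torch.cuda.graph` capture inside
# `graph_capture()` so collectives take the graph-safe paths described above.
def _example_graph_capture(static_input: torch.Tensor) -> torch.cuda.CUDAGraph:
    graph = torch.cuda.CUDAGraph()
    with graph_capture() as capture_context:
        # Capture on the dedicated stream provided by the context.
        with torch.cuda.graph(graph, stream=capture_context.stream):
            tensor_model_parallel_all_reduce(static_input)
    return graph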


def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across model parallel group.

    NOTE: This operation will be applied in-place on the input tensor if
    disable_custom_all_reduce is set to True. Otherwise, this operation may or
    may not be applied in place depending on whether custom all reduce is
    invoked for a particular tensor, which further depends on the tensor size
    and GPU topology.

    TLDR: always assume this function modifies its input, but use the return
    value as the output.
    """
    ca_comm = get_tp_ca_communicator()
    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size() == 1:
        return input_
    if ca_comm is not None:
        out = ca_comm.custom_all_reduce(input_)
        if out is not None:
            return out
    pynccl_comm = get_tp_pynccl_communicator()
    if (pynccl_comm is not None and not pynccl_comm.disabled):
        pynccl_comm.all_reduce(input_)
    else:
        torch.distributed.all_reduce(input_,
                                     group=get_tensor_model_parallel_group())
    return input_
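

# NOTE: Illustrative sketch only (not part of the original module). Per the
# docstring above, callers should always use the returned tensor rather than
# relying on whether the input was modified in place.
def _example_all_reduce(partial_output: torch.Tensor) -> torch.Tensor:
    reduced = tensor_model_parallel_all_reduce(partial_output)
    return reduced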


def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                     dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor across model parallel group."""
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    input_size = input_.size()
    # Allocate output tensor.
    output_tensor = torch.empty((world_size, ) + input_size,
                                dtype=input_.dtype,
                                device=input_.device)
    # All-gather.
    torch.distributed.all_gather_into_tensor(
        output_tensor, input_, group=get_tensor_model_parallel_group())
    # Reshape.
    output_tensor = output_tensor.movedim(0, dim)
    output_tensor = output_tensor.reshape(input_size[:dim] +
                                          (world_size * input_size[dim], ) +
                                          input_size[dim + 1:])
    return output_tensor
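

# NOTE: Illustrative sketch only (not part of the original module). With a
# tensor-parallel world size of N, gathering a [batch, hidden] tensor along
# the last dimension yields a [batch, hidden * N] tensor on every rank.
def _example_all_gather(partial_hidden: torch.Tensor) -> torch.Tensor:
    return tensor_model_parallel_all_gather(partial_hidden, dim=-1)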


def tensor_model_parallel_gather(input_: torch.Tensor,
                                 dst: int = 0,
                                 dim: int = -1) -> torch.Tensor:
    """Gather the input tensor across model parallel group.

    NOTE: We assume that the input tensor is on the same device across
    all the ranks.
    """
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    # Allocate output tensor.
    if get_tensor_model_parallel_rank() == dst:
        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
    else:
        gather_list = None
    # Gather.
    torch.distributed.gather(input_,
                             gather_list,
                             dst=dst,
                             group=get_tensor_model_parallel_group())
    if get_tensor_model_parallel_rank() == dst:
        output_tensor = torch.cat(gather_list, dim=dim)
    else:
        output_tensor = None
    return output_tensor
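

# NOTE: Illustrative sketch only (not part of the original module). Unlike
# all-gather, only the destination rank receives the concatenated tensor;
# every other rank gets None.
def _example_gather(local_logits: torch.Tensor) -> Optional[torch.Tensor]:
    gathered = tensor_model_parallel_gather(local_logits, dst=0, dim=-1)
    return gathered  # None on all ranks except dst.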


def broadcast(input_: torch.Tensor,
              src: int = 0,
              group: Optional[ProcessGroup] = None):
    """Broadcast the input tensor."""
    group = group or torch.distributed.group.WORLD
    ranks = torch.distributed.get_process_group_ranks(group)
    assert src in ranks, f"Invalid src rank ({src})"
    # Bypass the function if we are using only 1 GPU.
    world_size = torch.distributed.get_world_size(group=group)
    if world_size == 1:
        return input_
    # Broadcast.
    torch.distributed.broadcast(input_, src=src, group=group)
    return input_


def broadcast_object_list(obj_list: List[Any],
                          src: int = 0,
                          group: Optional[ProcessGroup] = None):
    """Broadcast the input object list."""
    group = group or torch.distributed.group.WORLD
    ranks = torch.distributed.get_process_group_ranks(group)
    assert src in ranks, f"Invalid src rank ({src})"
    # Bypass the function if we are using only 1 GPU.
    world_size = torch.distributed.get_world_size(group=group)
    if world_size == 1:
        return obj_list
    # Broadcast.
    torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
    return obj_list


TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])


def _split_tensor_dict(
        tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
       by its metadata.
    2. A list of tensors.
    """
    metadata_list = []
    tensor_list = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # NOTE: we cannot use `value.device` here, because it contains
            # not only the device type but also the device index (e.g.
            # "cuda:0"). We only need the device type; the receiving side
            # will set the device index.
            device = "cpu" if value.is_cpu else "cuda"
            metadata_list.append(
                (key, TensorMetadata(device, value.dtype, value.size())))
            tensor_list.append(value)
        else:
            metadata_list.append((key, value))
    return metadata_list, tensor_list
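

# NOTE: Illustrative sketch only (not part of the original module). For a dict
# with one CPU tensor and one plain value, the split looks like:
#   metadata_list -> [("x", TensorMetadata("cpu", torch.float32,
#                                          torch.Size([2, 2]))),
#                     ("step", 7)]
#   tensor_list   -> [tensor_dict["x"]]
def _example_split_tensor_dict() -> None:
    tensor_dict = {"x": torch.zeros(2, 2), "step": 7}
    metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
    assert len(metadata_list) == 2 and len(tensor_list) == 1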


def broadcast_tensor_dict(
        tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
        src: int = 0,
        group: Optional[ProcessGroup] = None,
        metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
    """Broadcast the input tensor dictionary.

    `group` is used to broadcast the tensors, while `metadata_group` is used
    to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
    dtypes).
    """
    # Bypass the function if we are using only 1 GPU.
    if (not torch.distributed.is_initialized()
            or torch.distributed.get_world_size(group=group) == 1):
        return tensor_dict

    group = group or torch.distributed.group.WORLD
    metadata_group = metadata_group or get_cpu_world_group()
    ranks = torch.distributed.get_process_group_ranks(group)
    assert src in ranks, f"Invalid src rank ({src})"
    rank = torch.distributed.get_rank()
    if rank == src:
        metadata_list: List[Tuple[Any, Any]] = []
        assert isinstance(
            tensor_dict,
            dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `broadcast_object_list` involves serialization and deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        torch.distributed.broadcast_object_list([metadata_list],
                                                src=src,
                                                group=metadata_group)
        async_handles = []
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip broadcasting empty tensors.
                continue
            if tensor.is_cpu:
                # Use metadata_group for CPU tensors.
                handle = torch.distributed.broadcast(tensor,
                                                     src=src,
                                                     group=metadata_group,
                                                     async_op=True)
            else:
                # Use group for GPU tensors.
                handle = torch.distributed.broadcast(tensor,
                                                     src=src,
                                                     group=group,
                                                     async_op=True)
            async_handles.append(handle)
        for async_handle in async_handles:
            async_handle.wait()
    else:
        recv_metadata_list = [None]
        torch.distributed.broadcast_object_list(recv_metadata_list,
                                                src=src,
                                                group=metadata_group)
        assert recv_metadata_list[0] is not None
        tensor_dict = {}
        async_handles = []
        for key, value in recv_metadata_list[0]:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
                                     device=value.device)
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    tensor_dict[key] = tensor
                    continue
                if tensor.is_cpu:
                    # Use metadata_group for CPU tensors.
                    handle = torch.distributed.broadcast(tensor,
                                                         src=src,
                                                         group=metadata_group,
                                                         async_op=True)
                else:
                    # Use group for GPU tensors.
                    handle = torch.distributed.broadcast(tensor,
                                                         src=src,
                                                         group=group,
                                                         async_op=True)
                async_handles.append(handle)
                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        for async_handle in async_handles:
            async_handle.wait()
    return tensor_dict
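

# NOTE: Illustrative sketch only (not part of the original module). The source
# rank passes the dictionary; all other ranks pass nothing and receive a
# reconstructed copy. Metadata travels over the CPU group, tensor data over
# the device group.
def _example_broadcast_tensor_dict(is_src: bool) -> Optional[Dict[Any, Any]]:
    if is_src:
        payload = {
            "input_ids": torch.ones(4, dtype=torch.long, device="cuda"),
            "num_steps": 3,
        }
        return broadcast_tensor_dict(payload, src=0)
    return broadcast_tensor_dict(src=0)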