- from collections import namedtuple
- from contextlib import contextmanager, nullcontext
- from dataclasses import dataclass
- from typing import Any, Dict, List, Optional, Tuple, Union
- import torch
- from torch.distributed import ProcessGroup
- from .parallel_state import (get_cpu_world_group, get_pp_pynccl_communicator,
- get_tensor_model_parallel_group,
- get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size,
- get_tp_ca_communicator,
- get_tp_pynccl_communicator)
- @dataclass
- class GraphCaptureContext:
- stream: torch.cuda.Stream
- @contextmanager
- def graph_capture():
- """
- `graph_capture` is a context manager that should surround the code that
- captures a CUDA graph. Its main purpose is to ensure that some operations
- are run after the graph is captured and before the graph is replayed. It
- returns a `GraphCaptureContext` object containing the data needed for the
- capture; currently that is only the stream the capture runs on. The
- current CUDA stream is switched to this stream when the context manager
- is entered and restored to the default stream when it exits. Running the
- capture on a separate stream explicitly distinguishes the kernels to
- capture from other kernels that may be launched in the background on the
- default stream.
- """
- stream = torch.cuda.Stream()
- graph_capture_context = GraphCaptureContext(stream)
- ca_comm = get_tp_ca_communicator()
- maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()
- with torch.cuda.stream(stream), maybe_ca_context:
- # During graph capture, enable the pynccl communicators (TP and PP) on the
- # capture stream so that collectives issued while capturing are recorded
- # through pynccl. Outside of this context the pynccl communicators stay
- # disabled, and `tensor_model_parallel_all_reduce` below falls back to the
- # custom all-reduce kernel (when available) or torch.distributed.
- tp_pynccl_comm = get_tp_pynccl_communicator()
- pp_pynccl_comm = get_pp_pynccl_communicator()
- if not tp_pynccl_comm:
- maybe_tp_pynccl_context = nullcontext()
- else:
- maybe_tp_pynccl_context = tp_pynccl_comm.change_state(
- enable=True, stream=torch.cuda.current_stream())
- if not pp_pynccl_comm:
- maybe_pp_pynccl_context = nullcontext()
- else:
- maybe_pp_pynccl_context = pp_pynccl_comm.change_state(
- enable=True, stream=torch.cuda.current_stream())
- with maybe_tp_pynccl_context, maybe_pp_pynccl_context:
- yield graph_capture_context
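- # Illustrative usage (a hedged sketch, not part of the original module):
- # `graph_capture` is meant to wrap CUDA graph capture on its dedicated
- # stream. `runner` and `static_inputs` are hypothetical placeholders for
- # the forward pass being captured and its preallocated inputs.
- #
- #     with graph_capture() as capture_ctx:
- #         graph = torch.cuda.CUDAGraph()
- #         with torch.cuda.graph(graph, stream=capture_ctx.stream):
- #             output = runner(static_inputs)
- #     # Replaying the graph later reuses the captured kernels:
- #     graph.replay()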
- def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
- """All-reduce the input tensor across model parallel group.
- NOTE: This operation will be applied in-place on the input tensor if
- disable_custom_all_reduce is set to True. Otherwise, this operation may or
- may not be applied in place depending on whether custom all reduce is
- invoked for a particular tensor, which further depends on the tensor size
- and GPU topology.
- TLDR: always assume this function modifies its input, but use the return
- value as the output.
- """
- ca_comm = get_tp_ca_communicator()
- # Bypass the function if we are using only 1 GPU.
- if get_tensor_model_parallel_world_size() == 1:
- return input_
- if ca_comm is not None:
- out = ca_comm.custom_all_reduce(input_)
- if out is not None:
- return out
- pynccl_comm = get_tp_pynccl_communicator()
- if (pynccl_comm is not None and not pynccl_comm.disabled):
- pynccl_comm.all_reduce(input_)
- else:
- torch.distributed.all_reduce(input_,
- group=get_tensor_model_parallel_group())
- return input_
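- # Illustrative usage (a hedged sketch, not part of the original module):
- # per the note above, callers should not rely on in-place semantics and
- # should always use the returned tensor. `hidden_states` is a hypothetical
- # activation with identical shape on every tensor-parallel rank.
- #
- #     hidden_states = tensor_model_parallel_all_reduce(hidden_states)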
- def tensor_model_parallel_all_gather(input_: torch.Tensor,
- dim: int = -1) -> torch.Tensor:
- """All-gather the input tensor across model parallel group."""
- world_size = get_tensor_model_parallel_world_size()
- # Bypass the function if we are using only 1 GPU.
- if world_size == 1:
- return input_
- assert -input_.dim() <= dim < input_.dim(), (
- f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
- if dim < 0:
- # Convert a negative dim index to its non-negative equivalent.
- dim += input_.dim()
- input_size = input_.size()
-
- output_tensor = torch.empty((world_size, ) + input_size,
- dtype=input_.dtype,
- device=input_.device)
-
- torch.distributed.all_gather_into_tensor(
- output_tensor, input_, group=get_tensor_model_parallel_group())
-
- output_tensor = output_tensor.movedim(0, dim)
- output_tensor = output_tensor.reshape(input_size[:dim] +
- (world_size * input_size[dim], ) +
- input_size[dim + 1:])
- return output_tensor
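- # Illustrative shape example (a hedged sketch, not part of the original
- # module). With a tensor-parallel world size of 4 and a per-rank tensor of
- # shape (batch, 1024), gathering along the last dim yields (batch, 4096).
- # `logits` is a hypothetical per-rank shard.
- #
- #     # per-rank shard: (batch, vocab_size // tp_size)
- #     full_logits = tensor_model_parallel_all_gather(logits, dim=-1)
- #     # full_logits: (batch, vocab_size)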
- def tensor_model_parallel_gather(input_: torch.Tensor,
- dst: int = 0,
- dim: int = -1) -> torch.Tensor:
- """Gather the input tensor across model parallel group.
- NOTE: We assume that the input tensor is on the same device across
- all the ranks.
- """
- world_size = get_tensor_model_parallel_world_size()
- # Bypass the function if we are using only 1 GPU.
- if world_size == 1:
- return input_
- assert -input_.dim() <= dim < input_.dim(), (
- f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
- if dim < 0:
- # Convert a negative dim index to its non-negative equivalent.
- dim += input_.dim()
-
- if get_tensor_model_parallel_rank() == dst:
- gather_list = [torch.empty_like(input_) for _ in range(world_size)]
- else:
- gather_list = None
-
- torch.distributed.gather(input_,
- gather_list,
- dst=dst,
- group=get_tensor_model_parallel_group())
- if get_tensor_model_parallel_rank() == dst:
- output_tensor = torch.cat(gather_list, dim=dim)
- else:
- output_tensor = None
- return output_tensor
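- # Illustrative usage (a hedged sketch, not part of the original module):
- # only the destination rank receives the concatenated tensor; every other
- # rank gets None, so callers typically guard on the result. `local_logits`
- # and `process` are hypothetical placeholders.
- #
- #     gathered = tensor_model_parallel_gather(local_logits, dst=0, dim=0)
- #     if gathered is not None:
- #         # Only rank `dst` reaches this branch.
- #         process(gathered)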
- def broadcast(input_: torch.Tensor,
- src: int = 0,
- group: Optional[ProcessGroup] = None):
- """Broadcast the input tensor."""
- group = group or torch.distributed.group.WORLD
- ranks = torch.distributed.get_process_group_ranks(group)
- assert src in ranks, f"Invalid src rank ({src})"
- # Bypass the function if we are using only 1 GPU.
- world_size = torch.distributed.get_world_size(group=group)
- if world_size == 1:
- return input_
-
- torch.distributed.broadcast(input_, src=src, group=group)
- return input_
- def broadcast_object_list(obj_list: List[Any],
- src: int = 0,
- group: Optional[ProcessGroup] = None):
- """Broadcast the input object list."""
- group = group or torch.distributed.group.WORLD
- ranks = torch.distributed.get_process_group_ranks(group)
- assert src in ranks, f"Invalid src rank ({src})"
- # Bypass the function if we are using only 1 GPU.
- world_size = torch.distributed.get_world_size(group=group)
- if world_size == 1:
- return obj_list
-
- torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
- return obj_list
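- # Illustrative usage (a hedged sketch, not part of the original module):
- # `torch.distributed.broadcast_object_list` overwrites the list entries in
- # place on non-source ranks, so every rank passes a list of the same
- # length and reads the result afterwards.
- #
- #     if torch.distributed.get_rank() == 0:
- #         objs = [{"seq_len": 128}]
- #     else:
- #         objs = [None]
- #     objs = broadcast_object_list(objs, src=0)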
- TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
- def _split_tensor_dict(
- tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
- ) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
- """Split the tensor dictionary into two parts:
- 1. A list of (key, value) pairs. If the value is a tensor, it is replaced
- by its metadata.
- 2. A list of tensors.
- """
- metadata_list = []
- tensor_list = []
- for key, value in tensor_dict.items():
- if isinstance(value, torch.Tensor):
- # Record only the device *type* ("cpu" or "cuda") rather than
- # `value.device`, which also carries a device index such as "cuda:0";
- # the receiving rank allocates the tensor on its own device.
- device = "cpu" if value.is_cpu else "cuda"
- metadata_list.append(
- (key, TensorMetadata(device, value.dtype, value.size())))
- tensor_list.append(value)
- else:
- metadata_list.append((key, value))
- return metadata_list, tensor_list
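- # Illustrative example (a hedged sketch, not part of the original module):
- # a dict mixing tensors and plain Python values is split into a picklable
- # metadata list plus the raw tensors, e.g.
- #
- #     {"ids": torch.zeros(4, dtype=torch.long), "is_prompt": True}
- #
- # becomes
- #
- #     metadata_list = [("ids", TensorMetadata("cpu", torch.long, torch.Size([4]))),
- #                      ("is_prompt", True)]
- #     tensor_list = [tensor_dict["ids"]]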
- def broadcast_tensor_dict(
- tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
- src: int = 0,
- group: Optional[ProcessGroup] = None,
- metadata_group: Optional[ProcessGroup] = None
- ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
- """Broadcast the input tensor dictionary.
- `group` is used to broadcast the tensors, while `metadata_group` is used
- to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
- dtypes).
- """
- # Bypass the function if torch.distributed is not initialized or only
- # 1 GPU is used.
- if (not torch.distributed.is_initialized()
- or torch.distributed.get_world_size(group=group) == 1):
- return tensor_dict
- group = group or torch.distributed.group.WORLD
- metadata_group = metadata_group or get_cpu_world_group()
- ranks = torch.distributed.get_process_group_ranks(group)
- assert src in ranks, f"Invalid src rank ({src})"
- rank = torch.distributed.get_rank()
- if rank == src:
- metadata_list: List[Tuple[Any, Any]] = []
- assert isinstance(
- tensor_dict,
- dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
- metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
- # `metadata_list` lives in CPU memory and `broadcast_object_list`
- # serializes/deserializes objects on the CPU, so the CPU-side
- # `metadata_group` is used for it regardless of where the tensors live.
- torch.distributed.broadcast_object_list([metadata_list],
- src=src,
- group=metadata_group)
- async_handles = []
- for tensor in tensor_list:
- if tensor.numel() == 0:
- # Skip broadcasting empty tensors.
- continue
- if tensor.is_cpu:
- # Use metadata_group (CPU) for CPU tensors.
- handle = torch.distributed.broadcast(tensor,
- src=src,
- group=metadata_group,
- async_op=True)
- else:
- # Use group for non-CPU (e.g. GPU) tensors.
- handle = torch.distributed.broadcast(tensor,
- src=src,
- group=group,
- async_op=True)
- async_handles.append(handle)
- for async_handle in async_handles:
- async_handle.wait()
- else:
- recv_metadata_list = [None]
- torch.distributed.broadcast_object_list(recv_metadata_list,
- src=src,
- group=metadata_group)
- assert recv_metadata_list[0] is not None
- tensor_dict = {}
- async_handles = []
- for key, value in recv_metadata_list[0]:
- if isinstance(value, TensorMetadata):
- tensor = torch.empty(value.size,
- dtype=value.dtype,
- device=value.device)
- if tensor.numel() == 0:
- # Empty tensors are not broadcast; keep the empty placeholder.
- tensor_dict[key] = tensor
- continue
- if tensor.is_cpu:
- # Use metadata_group (CPU) for CPU tensors.
- handle = torch.distributed.broadcast(tensor,
- src=src,
- group=metadata_group,
- async_op=True)
- else:
- # Use group for non-CPU (e.g. GPU) tensors.
- handle = torch.distributed.broadcast(tensor,
- src=src,
- group=group,
- async_op=True)
- async_handles.append(handle)
- tensor_dict[key] = tensor
- else:
- tensor_dict[key] = value
- for async_handle in async_handles:
- async_handle.wait()
- return tensor_dict
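- # Illustrative usage (a hedged sketch, not part of the original module):
- # the source rank passes the dict, the other ranks pass nothing and get a
- # reconstructed dict with tensors allocated on their own device.
- # `input_tokens` is a hypothetical placeholder tensor.
- #
- #     if torch.distributed.get_rank() == 0:
- #         data = broadcast_tensor_dict({"input_tokens": input_tokens,
- #                                       "is_prompt": True}, src=0)
- #     else:
- #         data = broadcast_tensor_dict(src=0)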