|
- # Copyright 2023 The PygmalionAI team.
- # Copyright 2023 The vLLM team.
- # Adapted from
- # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
- # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
- """Aphrodite distributed state.
- It takes over the control of the distributed environment from PyTorch.
- The typical workflow is:
- - call `init_distributed_environment` to initialize the distributed environment.
- - call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
- initialize the model parallel groups.
- - any code dealing with the distributed stuff
- - call `destroy_model_parallel` to destroy the model parallel groups.
- - call `destroy_distributed_environment` to destroy the distributed environment.
- If you only need to use the distributed environment without model/pipeline
- parallelism, you can skip the model parallel initialization and destruction
- steps.
- """
- import contextlib
- import os
- import pickle
- from collections import namedtuple
- from contextlib import contextmanager, nullcontext
- from dataclasses import dataclass
- from multiprocessing import shared_memory
- from typing import Any, Dict, List, Optional, Tuple, Union
- from unittest.mock import patch
- import torch
- import torch.distributed
- from loguru import logger
- from torch.distributed import Backend, ProcessGroup
@dataclass
class GraphCaptureContext:
    """State handed from one graph-capture context manager to the next.

    The only field is the CUDA stream on which the capture is performed, so
    that nested captures can share one stream.
    """
    # CUDA stream used for the graph capture
    stream: torch.cuda.Stream
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])


def _split_tensor_dict(
    tensor_dict: Dict[str, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    """Separate a mixed dictionary into picklable metadata and raw tensors.

    Returns a pair:
    1. A list of (key, value) pairs in iteration order, where every tensor
       value has been replaced by a `TensorMetadata` describing it.
    2. The list of the tensors themselves, in the same order.
    """
    entries: List[Tuple[str, Any]] = []
    tensors: List[torch.Tensor] = []
    for name, item in tensor_dict.items():
        if not isinstance(item, torch.Tensor):
            # Plain Python values travel inside the metadata unchanged.
            entries.append((name, item))
            continue
        # Only the device *type* is recorded (e.g. "cuda", not "cuda:0"):
        # the device index is decided by the receiving side.
        description = TensorMetadata(item.device.type, item.dtype,
                                     item.size())
        entries.append((name, description))
        tensors.append(item)
    return entries, tensors
class GroupCoordinator:
    """
    PyTorch ProcessGroup wrapper for a group of processes.

    PyTorch ProcessGroup is bound to one specific communication backend,
    e.g. NCCL, Gloo, MPI, etc.
    GroupCoordinator takes charge of all the communication operations among
    the processes in the group. It can route the communication to
    a specific implementation (e.g. switch allreduce implementation
    based on the tensor size and cuda graph mode).

    Throughout this class, `src` / `dst` arguments are ranks *inside the
    group* (indices into `self.ranks`), not global ranks.
    """

    # available attributes:
    rank: int  # global rank
    ranks: List[int]  # global ranks in the group
    world_size: int  # size of the group
    # difference between `local_rank` and `rank_in_group`:
    # if we have a group of size 4 across two nodes:
    # Process | Node | Rank | Local Rank | Rank in Group
    #   0     |   0  |  0   |     0      |       0
    #   1     |   0  |  1   |     1      |       1
    #   2     |   1  |  2   |     0      |       2
    #   3     |   1  |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    use_pynccl: bool  # a hint of whether to use PyNccl
    use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
    # communicators are only created for world size > 1
    pynccl_comm: Optional[Any]  # PyNccl communicator
    ca_comm: Optional[Any]  # Custom allreduce communicator
    tpu_communicator: Optional[Any]  # TPU communicator
    mq_broadcaster: Optional[Any]  # shared memory broadcaster

    def __init__(
        self,
        group_ranks: List[List[int]],
        local_rank: int,
        torch_distributed_backend: Union[str, Backend],
        use_pynccl: bool,
        use_custom_allreduce: bool,
        use_tpu_communicator: bool,
        use_message_queue_broadcaster: bool = False,
    ):
        """Create the process groups and the optional communicators.

        Args:
            group_ranks: one list of global ranks per group. A device group
                and a gloo CPU group are created for every entry; the entry
                that contains the calling rank becomes this coordinator's
                group.
            local_rank: local rank used to pick the CUDA device.
            torch_distributed_backend: backend of the device group.
            use_pynccl: create a PyNccl communicator (world size > 1 only).
            use_custom_allreduce: create a custom all-reduce communicator
                (world size > 1 only).
            use_tpu_communicator: create a TPU communicator
                (world size > 1 only).
            use_message_queue_broadcaster: create a shared-memory message
                queue used by `broadcast_object` (world size > 1 only).
        """
        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank
        self.device_group = None
        self.cpu_group = None

        # `new_group` is called for every entry of `group_ranks`, including
        # the groups this rank is not a member of.
        for ranks in group_ranks:
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend)
            # a group with `gloo` backend, to allow direct coordination between
            # processes through the CPU.
            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
                self.device_group = device_group
                self.cpu_group = cpu_group

        assert self.cpu_group is not None
        assert self.device_group is not None

        if torch.cuda.is_available():
            self.device = torch.device(f"cuda:{local_rank}")
        else:
            self.device = torch.device("cpu")

        self.use_pynccl = use_pynccl
        self.use_custom_allreduce = use_custom_allreduce
        self.use_tpu_communicator = use_tpu_communicator

        # lazy import to avoid documentation build error
        from aphrodite.distributed.device_communicators.custom_all_reduce import (  # noqa: E501
            CustomAllreduce)
        from aphrodite.distributed.device_communicators.pynccl import (
            PyNcclCommunicator)

        self.pynccl_comm: Optional[PyNcclCommunicator]
        if use_pynccl and self.world_size > 1:
            self.pynccl_comm = PyNcclCommunicator(
                group=self.cpu_group,
                device=self.device,
            )
        else:
            self.pynccl_comm = None

        self.ca_comm: Optional[CustomAllreduce]
        if use_custom_allreduce and self.world_size > 1:
            # Initialize a custom fast all-reduce implementation.
            self.ca_comm = CustomAllreduce(
                group=self.cpu_group,
                device=self.device,
            )
        else:
            self.ca_comm = None

        from aphrodite.distributed.device_communicators.tpu_communicator import (  # noqa: E501
            TpuCommunicator)
        # Always bind the attribute: `all_reduce` / `all_gather` read it
        # unconditionally once world_size > 1, so leaving it unset when the
        # TPU communicator is not requested would raise AttributeError.
        self.tpu_communicator: Optional[TpuCommunicator] = None
        if use_tpu_communicator and self.world_size > 1:
            self.tpu_communicator = TpuCommunicator(group=self.cpu_group)

        from aphrodite.distributed.device_communicators.shm_broadcast import (
            MessageQueue)
        self.mq_broadcaster: Optional[MessageQueue] = None
        if use_message_queue_broadcaster and self.world_size > 1:
            self.mq_broadcaster = MessageQueue.create_from_process_group(
                self.cpu_group, 1 << 22, 6)

    @property
    def first_rank(self):
        """Return the global rank of the first process in the group"""
        return self.ranks[0]

    @property
    def last_rank(self):
        """Return the global rank of the last process in the group"""
        return self.ranks[-1]

    @property
    def is_first_rank(self):
        """Return whether the caller is the first process in the group"""
        return self.rank == self.first_rank

    @property
    def is_last_rank(self):
        """Return whether the caller is the last process in the group"""
        return self.rank == self.last_rank

    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        # wraps around: the last process is followed by the first one
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        # wraps around: the first process is preceded by the last one
        return self.ranks[(rank_in_group - 1) % world_size]

    def _group_for(self, tensor: torch.Tensor):
        """Return `cpu_group` for CPU tensors, `device_group` otherwise."""
        return self.cpu_group if tensor.is_cpu else self.device_group

    @contextmanager
    def graph_capture(
            self,
            graph_capture_context: Optional["GraphCaptureContext"] = None):
        """Context manager to wrap around CUDA graph capture.

        Args:
            graph_capture_context: an existing context whose stream should be
                reused. When None, a new CUDA stream and a new context are
                created.

        Yields:
            The `GraphCaptureContext` in use.
        """
        if graph_capture_context is None:
            stream = torch.cuda.Stream()
            graph_capture_context = GraphCaptureContext(stream)
        else:
            stream = graph_capture_context.stream

        ca_comm = self.ca_comm
        maybe_ca_context = nullcontext(
        ) if ca_comm is None else ca_comm.capture()

        # ensure all initialization operations complete before attempting to
        # capture the graph on another stream
        curr_stream = torch.cuda.current_stream()
        if curr_stream != stream:
            stream.wait_stream(curr_stream)

        with torch.cuda.stream(stream), maybe_ca_context:
            # In graph mode, we have to be very careful about the collective
            # operations. The current status is:
            #     allreduce \ Mode   |  Eager  |  Graph  |
            # --------------------------------------------
            # custom allreduce       | enabled | enabled |
            # PyNccl                 | disabled| enabled |
            # torch.distributed      | enabled | disabled|
            #
            # Note that custom allreduce will have a runtime check, if the
            #  tensor size is too large, it will fallback to the next
            #  available option.
            # In summary: When using CUDA graph, we use
            #  either custom all-reduce kernel or pynccl. When not using
            #  CUDA graph, we use either custom all-reduce kernel or
            #  PyTorch NCCL. We always prioritize using custom all-reduce
            #  kernel but fall back to PyTorch or pynccl if it is
            #  disabled or not supported.
            pynccl_comm = self.pynccl_comm
            maybe_pynccl_context: Any
            if not pynccl_comm:
                maybe_pynccl_context = nullcontext()
            else:
                maybe_pynccl_context = pynccl_comm.change_state(
                    enable=True, stream=torch.cuda.current_stream())
            with maybe_pynccl_context:
                yield graph_capture_context

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """All-reduce the input tensor across the group.

        NOTE: This operation will be applied in-place or out-of-place.
        Always assume this function modifies its input, but use the return
        value as the output.
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

        # For TPUs, use TPU communicator.
        tpu_comm = self.tpu_communicator
        if tpu_comm is not None and not tpu_comm.disabled:
            return tpu_comm.all_reduce(input_)

        ca_comm = self.ca_comm
        if ca_comm is not None:
            # a None result means the custom kernel did not handle this
            # tensor; fall through to the next implementation
            out = ca_comm.custom_all_reduce(input_)
            if out is not None:
                return out
        pynccl_comm = self.pynccl_comm
        if (pynccl_comm is not None and not pynccl_comm.disabled):
            pynccl_comm.all_reduce(input_)
        elif input_.is_cpu:
            import intel_extension_for_pytorch as ipex
            ipex.distributed.all_reduce(input_, group=self.device_group)
        else:
            torch.distributed.all_reduce(input_, group=self.device_group)
        return input_

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """All-gather the input tensor and concatenate along `dim`.

        The result has the shape of `input_` with dimension `dim` multiplied
        by the group size. With a group of size 1 the input is returned as is.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

        # For TPUs, use TPU communicator.
        tpu_comm = self.tpu_communicator
        if tpu_comm is not None and not tpu_comm.disabled:
            return tpu_comm.all_gather(input_, dim)

        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()
        # Allocate output tensor.
        output_tensor = torch.empty((world_size, ) + input_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        torch.distributed.all_gather_into_tensor(output_tensor,
                                                 input_,
                                                 group=self.device_group)
        # Reshape: move the per-rank axis next to `dim`, then merge the two.
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(input_size[:dim] +
                                              (world_size *
                                               input_size[dim], ) +
                                              input_size[dim + 1:])
        return output_tensor

    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> Optional[torch.Tensor]:
        """Gather the input tensor on rank `dst`, concatenated along `dim`.

        Returns the concatenated tensor on the destination rank and None on
        every other rank (or the input itself when the group has size 1).

        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        # Allocate output tensor.
        if self.rank_in_group == dst:
            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
        else:
            gather_list = None
        # Gather.
        torch.distributed.gather(input_,
                                 gather_list,
                                 dst=self.ranks[dst],
                                 group=self.device_group)
        if self.rank_in_group == dst:
            output_tensor = torch.cat(gather_list, dim=dim)
        else:
            output_tensor = None
        return output_tensor

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor.

        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
        torch.distributed.broadcast(input_,
                                    src=self.ranks[src],
                                    group=self.device_group)
        return input_

    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
        """Broadcast the input object and return it on every rank.

        The shared-memory message queue is used when it was created;
        otherwise the object is broadcast over the CPU group.

        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
        if self.rank_in_group == src:
            torch.distributed.broadcast_object_list([obj],
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return obj
        else:
            recv = [None]
            torch.distributed.broadcast_object_list(recv,
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return recv[0]

    def broadcast_object_list(self,
                              obj_list: List[Any],
                              src: int = 0,
                              group: Optional[ProcessGroup] = None):
        """Broadcast the input object list.

        NOTE: `src` is the local rank of the source rank.
        NOTE: `group` is accepted but not used; the broadcast always runs
        over `device_group`.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
        torch.distributed.broadcast_object_list(obj_list,
                                                src=self.ranks[src],
                                                group=self.device_group)
        return obj_list

    def send_object(self, obj: Any, dst: int) -> None:
        """Send the input object to the destination rank.

        The object is pickled and sent over the CPU group as two messages:
        its size, then its bytes.

        NOTE: `dst` is the local rank of the destination rank.
        """
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        assert dst != self.rank_in_group, (
            "Invalid destination rank. Destination rank is the same "
            "as the current rank.")

        # Serialize object to tensor and get the size as well
        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

        size_tensor = torch.tensor([object_tensor.numel()],
                                   dtype=torch.long,
                                   device="cpu")

        # Send object size
        torch.distributed.send(size_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        # Send object
        torch.distributed.send(object_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        return None

    def recv_object(self, src: int) -> Any:
        """Receive an object sent with `send_object` from the source rank.

        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        assert src != self.rank_in_group, (
            "Invalid source rank. Source rank is the same as the current rank."
        )

        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

        # Receive object size
        rank_size = torch.distributed.recv(size_tensor,
                                           src=self.ranks[src],
                                           group=self.cpu_group)

        # Tensor to receive serialized objects into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
            device="cpu")

        rank_object = torch.distributed.recv(object_tensor,
                                             src=self.ranks[src],
                                             group=self.cpu_group)

        assert rank_object == rank_size, (
            "Received object sender rank does not match the size sender rank.")

        # SECURITY NOTE: unpickling executes arbitrary code. This is only
        # acceptable because the bytes come from a peer rank of the same job;
        # never feed data from outside the process group into this path.
        obj = pickle.loads(object_tensor.numpy().tobytes())

        return obj

    def broadcast_tensor_dict(
        self,
        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
        src: int = 0,
        group: Optional[ProcessGroup] = None,
        metadata_group: Optional[ProcessGroup] = None
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Broadcast the input tensor dictionary.

        The metadata (keys, non-tensor values, tensor descriptions) is
        broadcast first with `broadcast_object`; every non-empty tensor is
        then broadcast asynchronously and waited on before returning.

        NOTE: `src` is the local rank of the source rank.
        NOTE: `group` and `metadata_group` are accepted but not used; GPU
        tensors always go over `device_group` and CPU tensors over
        `cpu_group`.
        """
        # Bypass the function if we are using only 1 GPU.
        if (not torch.distributed.is_initialized() or self.world_size == 1):
            return tensor_dict

        assert src < self.world_size, f"Invalid src rank ({src})"

        if self.rank_in_group == src:
            assert isinstance(
                tensor_dict,
                dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
            self.broadcast_object(metadata_list, src=src)
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                # CPU tensors use the CPU group, GPU tensors the device group
                handle = torch.distributed.broadcast(
                    tensor,
                    src=self.ranks[src],
                    group=self._group_for(tensor),
                    async_op=True)
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()

        else:
            metadata_list = self.broadcast_object(None, src=src)
            tensor_dict = {}
            async_handles = []
            for key, value in metadata_list:
                if isinstance(value, TensorMetadata):
                    # allocate the receive buffer from the metadata
                    tensor = torch.empty(value.size,
                                         dtype=value.dtype,
                                         device=value.device)
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
                        tensor_dict[key] = tensor
                        continue
                    # CPU tensors use the CPU group, GPU tensors the device
                    # group
                    handle = torch.distributed.broadcast(
                        tensor,
                        src=self.ranks[src],
                        group=self._group_for(tensor),
                        async_op=True)
                    async_handles.append(handle)
                    tensor_dict[key] = tensor
                else:
                    tensor_dict[key] = value
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict

    def send_tensor_dict(
        self,
        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
        dst: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Send the input tensor dictionary.

        Returns the input unchanged when there is nothing to communicate
        (torch.distributed not initialized or group of size 1), None
        otherwise.

        When `all_gather_group` is given, only this rank's slice of each
        tensor whose element count is divisible by that group's size is
        sent; the receiver rebuilds the tensor with an all-gather.

        NOTE: `dst` is the local rank of the destination rank. It defaults
        to the next rank in the group.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict

        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        assert isinstance(
            tensor_dict,
            dict), f"Expecting a dictionary, got {type(tensor_dict)}"
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.send_object(metadata_list, dst=dst)
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip sending empty tensors.
                continue

            # send-allgather: send only a slice, then do allgather.
            if (all_gather_group is not None
                    and tensor.numel() % all_gather_size == 0):
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

            # CPU tensors use the CPU group, GPU tensors the device group
            torch.distributed.send(tensor,
                                   dst=self.ranks[dst],
                                   group=self._group_for(tensor))
        return None

    def recv_tensor_dict(
        self,
        src: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Recv the input tensor dictionary.

        Returns None when there is nothing to communicate (torch.distributed
        not initialized or group of size 1).

        NOTE: `src` is the local rank of the source rank. It defaults to the
        previous rank in the group.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None

        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size, f"Invalid src rank ({src})"

        recv_metadata_list = self.recv_object(src=src)
        tensor_dict: Dict[str, Any] = {}
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
                                     device=value.device)
                if tensor.numel() == 0:
                    # Skip receiving empty tensors.
                    tensor_dict[key] = tensor
                    continue

                # send-allgather: send only a slice, then do allgather.
                use_all_gather = (all_gather_group is not None
                                  and tensor.numel() % all_gather_size == 0)

                if use_all_gather:
                    orig_shape = tensor.shape
                    tensor = tensor.reshape(all_gather_size,
                                            -1)[all_gather_rank]

                # CPU tensors use the CPU group, GPU tensors the device group
                torch.distributed.recv(tensor,
                                       src=self.ranks[src],
                                       group=self._group_for(tensor))
                if use_all_gather:
                    # do the allgather
                    tensor = all_gather_group.all_gather(  # type: ignore
                        tensor, dim=0)
                    tensor = tensor.reshape(orig_shape)

                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        return tensor_dict

    def barrier(self):
        """Barrier synchronization among the group.

        NOTE: don't use `device_group` here! `barrier` in NCCL is
        terrible because it is internally a broadcast operation with
        secretly created GPU tensors. It is easy to mess up the current
        device. Use the CPU group instead.
        """
        torch.distributed.barrier(group=self.cpu_group)

    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Send a tensor to the destination rank.

        The PyNccl communicator is used when it exists and is enabled;
        otherwise `torch.distributed.send` on the device group is used.

        NOTE: `dst` is the local rank of the destination rank. It defaults
        to the next rank in the group.
        """
        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size

        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.send(tensor, dst)
        else:
            torch.distributed.send(tensor, self.ranks[dst], self.device_group)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receive a tensor of the given size and dtype from the src rank.

        The tensor is allocated on `self.device`.

        NOTE: `src` is the local rank of the source rank. It defaults to the
        previous rank in the group.
        """
        if src is None:
            src = (self.rank_in_group - 1) % self.world_size

        tensor = torch.empty(size, dtype=dtype, device=self.device)
        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.recv(tensor, src)
        else:
            torch.distributed.recv(tensor, self.ranks[src], self.device_group)
        return tensor

    def destroy(self):
        """Destroy both process groups and drop every communicator."""
        if self.device_group is not None:
            torch.distributed.destroy_process_group(self.device_group)
            self.device_group = None
        if self.cpu_group is not None:
            torch.distributed.destroy_process_group(self.cpu_group)
            self.cpu_group = None
        # The communicators were built on top of the groups destroyed above,
        # so none of them may be kept around.
        self.pynccl_comm = None
        self.ca_comm = None
        self.tpu_communicator = None
        self.mq_broadcaster = None
_WORLD: Optional[GroupCoordinator] = None


def get_world_group() -> GroupCoordinator:
    """Return the world group coordinator; it must already be initialized."""
    world_group = _WORLD
    assert world_group is not None, ("world group is not initialized")
    return world_group
def init_world_group(ranks: List[int], local_rank: int,
                     backend: str) -> GroupCoordinator:
    """Build the coordinator for the single group made of `ranks`.

    PyNccl, custom all-reduce and the TPU communicator are all switched off
    for this group.
    """
    world_coordinator = GroupCoordinator(
        group_ranks=[ranks],
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=False,
        use_custom_allreduce=False,
        use_tpu_communicator=False,
    )
    return world_coordinator
def init_model_parallel_group(
    group_ranks: List[List[int]],
    local_rank: int,
    backend: str,
    use_custom_allreduce: Optional[bool] = None,
    use_message_queue_broadcaster: bool = False,
) -> GroupCoordinator:
    """Build a coordinator for a set of model-parallel groups.

    PyNccl and the TPU communicator are always requested. When
    `use_custom_allreduce` is None, the module-wide
    `_ENABLE_CUSTOM_ALL_REDUCE` setting decides.
    """
    custom_allreduce_enabled = (_ENABLE_CUSTOM_ALL_REDUCE
                                if use_custom_allreduce is None else
                                use_custom_allreduce)
    return GroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=True,
        use_custom_allreduce=custom_allreduce_enabled,
        use_tpu_communicator=True,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
    )
_TP: Optional[GroupCoordinator] = None


def get_tp_group() -> GroupCoordinator:
    """Return the tensor model parallel group; it must be initialized."""
    tp_group = _TP
    assert tp_group is not None, (
        "tensor model parallel group is not initialized")
    return tp_group


# kept for backward compatibility
get_tensor_model_parallel_group = get_tp_group
_PP: Optional[GroupCoordinator] = None


def get_pp_group() -> GroupCoordinator:
    """Return the pipeline model parallel group; it must be initialized."""
    pp_group = _PP
    assert pp_group is not None, (
        "pipeline model parallel group is not initialized")
    return pp_group


# kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group
@contextmanager
def graph_capture():
    """
    Context manager that should surround the code capturing a CUDA graph.

    It enters the tensor-parallel group's `graph_capture` first and then the
    pipeline-parallel group's `graph_capture`, handing the context produced
    by the former to the latter so that both share it. The yielded
    `GraphCaptureContext` carries the data needed for the capture; at the
    moment that is only the stream the capture runs on.
    """
    with get_tp_group().graph_capture() as context:
        # the PP group reuses the context created by the TP group
        with get_pp_group().graph_capture(context):
            yield context
_ENABLE_CUSTOM_ALL_REDUCE = True


def set_custom_all_reduce(enable: bool):
    """Set the module-wide default for using custom all-reduce."""
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    """Initialize torch.distributed (if needed) and the world group.

    Args:
        world_size: number of processes, forwarded to `init_process_group`.
        rank: global rank of this process, forwarded to
            `init_process_group`.
        distributed_init_method: init method URL for `init_process_group`.
        local_rank: local rank used to assign devices. When -1 it is taken
            from the `LOCAL_RANK` environment variable for the "env://"
            init method (falling back to `rank` if unset), and from `rank`
            otherwise.
        backend: backend of the default process group and the world group.
    """
    global _WORLD
    logger.debug(
        f"world_size={world_size} rank={rank} local_rank={local_rank} "
        f"distributed_init_method={distributed_init_method} backend={backend}")
    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment")
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank)
    # set the local rank
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # local rank not set, this usually happens in single-node
        # setting, where we can use rank as local rank
        if distributed_init_method == "env://":
            # Environment values are strings; convert so that `local_rank`
            # keeps its documented `int` type for every downstream user.
            env_local_rank = os.environ.get("LOCAL_RANK")
            local_rank = (rank if env_local_rank is None else
                          int(env_local_rank))
        else:
            local_rank = rank
    if _WORLD is None:
        ranks = list(range(torch.distributed.get_world_size()))
        _WORLD = init_world_group(ranks, local_rank, backend)
    else:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size")
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Create the tensor and pipeline model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        backend: backend for the new groups; defaults to the backend of the
            world group's device group.

    Example with 8 GPUs g0 ... g7, tensor size 2 and pipeline size 4: the
    function creates 4 tensor model-parallel groups and 2 pipeline
    model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.

    Raises:
        RuntimeError: if the world size is not the product of the two sizes.
    """
    global _TP, _PP

    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    expected_world_size = (tensor_model_parallel_size *
                           pipeline_model_parallel_size)
    if world_size != expected_world_size:
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")

    # Tensor model-parallel groups: consecutive blocks of ranks.
    tp_group_count: int = world_size // tensor_model_parallel_size
    assert _TP is None, ("tensor model parallel group is already initialized")
    tp_group_ranks = [
        list(
            range(index * tensor_model_parallel_size,
                  (index + 1) * tensor_model_parallel_size))
        for index in range(tp_group_count)
    ]
    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(tp_group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_message_queue_broadcaster=True)

    # Pipeline model-parallel groups: ranks strided by the group count.
    pp_group_count: int = world_size // pipeline_model_parallel_size
    assert _PP is None, (
        "pipeline model parallel group is already initialized")
    pp_group_ranks = [
        list(range(offset, world_size, pp_group_count))
        for offset in range(pp_group_count)
    ]
    # pipeline parallel does not need custom allreduce
    _PP = init_model_parallel_group(pp_group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_custom_allreduce=False)
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Initialize the model parallel groups, or validate existing ones.

    When the groups do not exist yet they are created with the given sizes.
    When they already exist, their sizes are asserted to match the expected
    tensor-parallel and pipeline-parallel sizes.
    """
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)
    if model_parallel_is_initialized():
        assert (
            get_tensor_model_parallel_world_size() == tensor_model_parallel_size
        ), ("tensor parallel group already initialized, but of unexpected size: "
            f"{get_tensor_model_parallel_world_size()=} vs. "
            f"{tensor_model_parallel_size=}")
        pp_world_size = get_pp_group().world_size
        assert (pp_world_size == pipeline_model_parallel_size), (
            "pipeline parallel group already initialized, but of unexpected size: "
            f"{pp_world_size=} vs. "
            f"{pipeline_model_parallel_size=}")
    else:
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size, backend)
def model_parallel_is_initialized():
    """Return True only when both the TP and the PP group exist."""
    return not (_TP is None or _PP is None)
_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    Nested use is rejected with an assertion. The previous group and the
    patched flag are restored on exit, also when the body raises.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator
    """
    global _TP, _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    # Resolve the current group *before* flipping the flag: `get_tp_group`
    # asserts when TP is not initialized, and a flag set ahead of that
    # failure would never be reset, blocking every later patch attempt.
    old_tp_group = get_tp_group()
    _TP_STATE_PATCHED = True
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group
def get_tensor_model_parallel_world_size():
    """Return the number of processes in the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.world_size
def get_tensor_model_parallel_rank():
    """Return the caller's rank inside the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.rank_in_group
def destroy_model_parallel():
    """Destroy the TP and PP groups (when present) and reset them to None."""
    global _TP, _PP

    # tensor model parallel group first, then pipeline model parallel group
    if _TP:
        _TP.destroy()
    _TP = None

    if _PP:
        _PP.destroy()
    _PP = None
def destroy_distributed_environment():
    """Destroy the world group, then the default torch process group."""
    global _WORLD

    if _WORLD:
        _WORLD.destroy()
    _WORLD = None

    # tear down torch.distributed itself only if it is still up
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
    """
    Collective operation telling, for every rank of `pg`, whether it runs on
    the same node as `source_rank`.

    The source rank creates a shared memory segment holding a magic message
    and broadcasts its name; every other rank tries to open the segment and
    read the message back. A rank that succeeds shares a memory system with
    the source. The per-rank flags are combined with an all-reduce.

    `pg` must not be an NCCL group. Errors while probing are logged and
    ignored; the affected rank simply reports itself as not on the node.
    """
    assert torch.distributed.get_backend(
        pg) != torch.distributed.Backend.NCCL, (
            "in_the_same_node_as should be tested with a non-NCCL group.")

    # rank and size relative to the group
    group_rank = torch.distributed.get_rank(group=pg)
    group_size = torch.distributed.get_world_size(group=pg)

    # one slot per rank; each process only ever sets its own slot
    same_node_flags = torch.tensor([0] * group_size, dtype=torch.int32)

    # global ranks of the processes in the group
    global_ranks = torch.distributed.get_process_group_ranks(pg)

    magic_message = b"magic_message"
    magic_len = len(magic_message)
    segment = None

    try:
        with contextlib.suppress(OSError):
            if group_rank != source_rank:
                # receive the segment name and try to open the segment
                received = [None]
                torch.distributed.broadcast_object_list(
                    received, src=global_ranks[source_rank], group=pg)
                segment_name = received[0]
                # fix to https://stackoverflow.com/q/62748654/9191338
                # Python incorrectly tracks shared memory even if it is not
                # created by the process. The following patch is a workaround.
                with patch("multiprocessing.resource_tracker.register",
                           lambda *args, **kwargs: None):
                    segment = shared_memory.SharedMemory(name=segment_name)
                if segment.buf[:magic_len] == magic_message:
                    same_node_flags[group_rank] = 1
            else:
                # create a shared memory segment and publish its name
                segment = shared_memory.SharedMemory(create=True, size=128)
                segment.buf[:magic_len] = magic_message
                torch.distributed.broadcast_object_list(
                    [segment.name], src=global_ranks[source_rank], group=pg)
                same_node_flags[group_rank] = 1
    except Exception as e:
        logger.error(f"Error ignored in is_in_the_same_node: {e}")
    finally:
        if segment:
            segment.close()

    # every rank must be done with the segment before it is removed
    torch.distributed.barrier(group=pg)

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if group_rank == source_rank and segment:
            segment.unlink()

    torch.distributed.all_reduce(same_node_flags, group=pg)

    return [flag == 1 for flag in same_node_flags.tolist()]
def get_current_tp_rank_partition_offset(total_size: int,
                                         tp_rank: Optional[int] = None,
                                         tp_size: Optional[int] = None,
                                         multiple_of: int = 1) -> int:
    """Return the start offset of `tp_rank`'s partition of `total_size`.

    `total_size` is split into units of `multiple_of` elements, which are
    divided among `tp_size` ranks; the leftover units go one each to the
    lowest ranks. `tp_rank` / `tp_size` default to the current tensor model
    parallel rank / world size. `total_size` must be a multiple of
    `multiple_of`.
    """
    if tp_rank is None:
        tp_rank = get_tensor_model_parallel_rank()
    if tp_size is None:
        tp_size = get_tensor_model_parallel_world_size()
    assert total_size % multiple_of == 0
    units_per_rank, leftover_units = divmod(total_size // multiple_of,
                                            tp_size)
    # every preceding rank holds `units_per_rank` units, and the first
    # `leftover_units` ranks hold one extra unit each
    units_before = units_per_rank * tp_rank + min(leftover_units, tp_rank)
    return units_before * multiple_of
def get_current_tp_rank_partition_size(total_size: int,
                                       tp_rank: Optional[int] = None,
                                       tp_size: Optional[int] = None,
                                       multiple_of: int = 1) -> int:
    """Return the size of `tp_rank`'s partition of `total_size`.

    `total_size` is split into units of `multiple_of` elements, which are
    divided among `tp_size` ranks; the leftover units go one each to the
    lowest ranks. `tp_rank` / `tp_size` default to the current tensor model
    parallel rank / world size. `total_size` must be a multiple of
    `multiple_of`.
    """
    if tp_rank is None:
        tp_rank = get_tensor_model_parallel_rank()
    if tp_size is None:
        tp_size = get_tensor_model_parallel_world_size()
    assert total_size % multiple_of == 0
    units_per_rank, leftover_units = divmod(total_size // multiple_of,
                                            tp_size)
    # ranks below `leftover_units` take one extra unit
    extra_unit = 1 if leftover_units > tp_rank else 0
    return (units_per_rank + extra_unit) * multiple_of
|