|
@@ -3,48 +3,834 @@
|
|
|
# Adapted from
|
|
|
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
|
|
|
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
|
|
-"""Tensor and pipeline parallel groups."""
|
|
|
+"""Aphrodite distributed state.
|
|
|
+It takes over control of the distributed environment from PyTorch.
|
|
|
+The typical workflow is:
|
|
|
+
|
|
|
+- call `init_distributed_environment` to initialize the distributed environment.
|
|
|
+- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
|
|
|
+ initialize the model parallel groups.
|
|
|
+
|
|
|
+- run any code that uses the distributed environment.
|
|
|
+
|
|
|
+- call `destroy_model_parallel` to destroy the model parallel groups.
|
|
|
+- call `destroy_distributed_environment` to destroy the distributed environment.
|
|
|
+
|
|
|
+If you only need to use the distributed environment without model/pipeline
|
|
|
+ parallelism, you can skip the model parallel initialization and destruction
|
|
|
+ steps.
|
|
|
+"""
|
|
|
import contextlib
|
|
|
-from typing import Optional
|
|
|
+import os
|
|
|
+import pickle
|
|
|
+from collections import namedtuple
|
|
|
+from contextlib import contextmanager, nullcontext
|
|
|
+from dataclasses import dataclass
|
|
|
+from multiprocessing import shared_memory
|
|
|
+from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
+from unittest.mock import patch
|
|
|
|
|
|
import torch
|
|
|
+import torch.distributed
|
|
|
from loguru import logger
|
|
|
+from torch.distributed import Backend, ProcessGroup
|
|
|
+
|
|
|
+
|
|
|
+@dataclass
|
|
|
+class GraphCaptureContext:
|
|
|
+ stream: torch.cuda.Stream
|
|
|
+
|
|
|
+
|
|
|
+TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
|
|
|
+
|
|
|
+
|
|
|
+def _split_tensor_dict(
|
|
|
+ tensor_dict: Dict[str, Union[torch.Tensor, Any]]
|
|
|
+) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
|
|
|
+ """Split the tensor dictionary into two parts:
|
|
|
+ 1. A list of (key, value) pairs. If the value is a tensor, it is replaced
|
|
|
+ by its metadata.
|
|
|
+ 2. A list of tensors.
|
|
|
+ """
|
|
|
+ metadata_list: List[Tuple[str, Any]] = []
|
|
|
+ tensor_list: List[torch.Tensor] = []
|
|
|
+ for key, value in tensor_dict.items():
|
|
|
+ if isinstance(value, torch.Tensor):
|
|
|
+ # Note: we cannot use `value.device` here,
|
|
|
+ # because it contains not only the device type but also the device
|
|
|
+ # index (e.g. "cuda:0"). We only need the device type.
|
|
|
+            # The receiving side will set the device index.
|
|
|
+ device = value.device.type
|
|
|
+ metadata_list.append(
|
|
|
+ (key, TensorMetadata(device, value.dtype, value.size())))
|
|
|
+ tensor_list.append(value)
|
|
|
+ else:
|
|
|
+ metadata_list.append((key, value))
|
|
|
+ return metadata_list, tensor_list
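+
+
+# Illustrative sketch of `_split_tensor_dict` (example values are assumed):
+#
+#     d = {"x": torch.zeros(2, 3, device="cuda:0"), "step": 7}
+#     metadata_list, tensor_list = _split_tensor_dict(d)
+#     # metadata_list == [("x", TensorMetadata("cuda", torch.float32,
+#     #                                         torch.Size([2, 3]))),
+#     #                   ("step", 7)]
+#     # tensor_list == [d["x"]]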
|
|
|
+
|
|
|
+
|
|
|
+class GroupCoordinator:
|
|
|
+ """
|
|
|
+ PyTorch ProcessGroup wrapper for a group of processes.
|
|
|
+ PyTorch ProcessGroup is bound to one specific communication backend,
|
|
|
+ e.g. NCCL, Gloo, MPI, etc.
|
|
|
+ GroupCoordinator takes charge of all the communication operations among
|
|
|
+ the processes in the group. It can route the communication to
|
|
|
+ a specific implementation (e.g. switch allreduce implementation
|
|
|
+ based on the tensor size and cuda graph mode).
|
|
|
+ """
|
|
|
+
|
|
|
+ # available attributes:
|
|
|
+ rank: int # global rank
|
|
|
+ ranks: List[int] # global ranks in the group
|
|
|
+ world_size: int # size of the group
|
|
|
+ # difference between `local_rank` and `rank_in_group`:
|
|
|
+ # if we have a group of size 4 across two nodes:
|
|
|
+ # Process | Node | Rank | Local Rank | Rank in Group
|
|
|
+ # 0 | 0 | 0 | 0 | 0
|
|
|
+ # 1 | 0 | 1 | 1 | 1
|
|
|
+ # 2 | 1 | 2 | 0 | 2
|
|
|
+ # 3 | 1 | 3 | 1 | 3
|
|
|
+ local_rank: int # local rank used to assign devices
|
|
|
+ rank_in_group: int # rank inside the group
|
|
|
+ cpu_group: ProcessGroup # group for CPU communication
|
|
|
+ device_group: ProcessGroup # group for device communication
|
|
|
+ use_pynccl: bool # a hint of whether to use PyNccl
|
|
|
+ use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
|
|
|
+ # communicators are only created for world size > 1
|
|
|
+ pynccl_comm: Optional[Any] # PyNccl communicator
|
|
|
+ ca_comm: Optional[Any] # Custom allreduce communicator
|
|
|
+ mq_broadcaster: Optional[Any] # shared memory broadcaster
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ group_ranks: List[List[int]],
|
|
|
+ local_rank: int,
|
|
|
+ torch_distributed_backend: Union[str, Backend],
|
|
|
+ use_pynccl: bool,
|
|
|
+ use_custom_allreduce: bool,
|
|
|
+ use_tpu_communicator: bool,
|
|
|
+ use_message_queue_broadcaster: bool = False,
|
|
|
+ ):
|
|
|
+
|
|
|
+ self.rank = torch.distributed.get_rank()
|
|
|
+ self.local_rank = local_rank
|
|
|
+ self.device_group = None
|
|
|
+ self.cpu_group = None
|
|
|
+
|
|
|
+ for ranks in group_ranks:
|
|
|
+ device_group = torch.distributed.new_group(
|
|
|
+ ranks, backend=torch_distributed_backend)
|
|
|
+ # a group with `gloo` backend, to allow direct coordination between
|
|
|
+ # processes through the CPU.
|
|
|
+ cpu_group = torch.distributed.new_group(ranks, backend="gloo")
|
|
|
+ if self.rank in ranks:
|
|
|
+ self.ranks = ranks
|
|
|
+ self.world_size = len(ranks)
|
|
|
+ self.rank_in_group = ranks.index(self.rank)
|
|
|
+ self.device_group = device_group
|
|
|
+ self.cpu_group = cpu_group
|
|
|
+
|
|
|
+ assert self.cpu_group is not None
|
|
|
+ assert self.device_group is not None
|
|
|
+
|
|
|
+ if torch.cuda.is_available():
|
|
|
+ self.device = torch.device(f"cuda:{local_rank}")
|
|
|
+ else:
|
|
|
+ self.device = torch.device("cpu")
|
|
|
+
|
|
|
+ self.use_pynccl = use_pynccl
|
|
|
+ self.use_custom_allreduce = use_custom_allreduce
|
|
|
+ self.use_tpu_communicator = use_tpu_communicator
|
|
|
+
|
|
|
+ # lazy import to avoid documentation build error
|
|
|
+ from aphrodite.distributed.device_communicators.custom_all_reduce import ( # noqa: E501
|
|
|
+ CustomAllreduce)
|
|
|
+ from aphrodite.distributed.device_communicators.pynccl import (
|
|
|
+ PyNcclCommunicator)
|
|
|
+
|
|
|
+ self.pynccl_comm: Optional[PyNcclCommunicator]
|
|
|
+ if use_pynccl and self.world_size > 1:
|
|
|
+ self.pynccl_comm = PyNcclCommunicator(
|
|
|
+ group=self.cpu_group,
|
|
|
+ device=self.device,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ self.pynccl_comm = None
|
|
|
+
|
|
|
+ self.ca_comm: Optional[CustomAllreduce]
|
|
|
+ if use_custom_allreduce and self.world_size > 1:
|
|
|
+ # Initialize a custom fast all-reduce implementation.
|
|
|
+ self.ca_comm = CustomAllreduce(
|
|
|
+ group=self.cpu_group,
|
|
|
+ device=self.device,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ self.ca_comm = None
|
|
|
+
|
|
|
+ from aphrodite.distributed.device_communicators.tpu_communicator import ( # noqa: E501
|
|
|
+ TpuCommunicator)
|
|
|
+        self.tpu_communicator: Optional[TpuCommunicator] = None
|
|
|
+ if use_tpu_communicator and self.world_size > 1:
|
|
|
+ self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
|
|
|
+
|
|
|
+ from aphrodite.distributed.device_communicators.shm_broadcast import (
|
|
|
+ MessageQueue)
|
|
|
+ self.mq_broadcaster: Optional[MessageQueue] = None
|
|
|
+ if use_message_queue_broadcaster and self.world_size > 1:
|
|
|
+ self.mq_broadcaster = MessageQueue.create_from_process_group(
|
|
|
+ self.cpu_group, 1 << 22, 6)
|
|
|
+
|
|
|
+ @property
|
|
|
+ def first_rank(self):
|
|
|
+ """Return the global rank of the first process in the group"""
|
|
|
+ return self.ranks[0]
|
|
|
+
|
|
|
+ @property
|
|
|
+ def last_rank(self):
|
|
|
+ """Return the global rank of the last process in the group"""
|
|
|
+ return self.ranks[-1]
|
|
|
+
|
|
|
+ @property
|
|
|
+ def is_first_rank(self):
|
|
|
+ """Return whether the caller is the first process in the group"""
|
|
|
+ return self.rank == self.first_rank
|
|
|
+
|
|
|
+ @property
|
|
|
+ def is_last_rank(self):
|
|
|
+ """Return whether the caller is the last process in the group"""
|
|
|
+ return self.rank == self.last_rank
|
|
|
+
|
|
|
+ @property
|
|
|
+ def next_rank(self):
|
|
|
+ """Return the global rank of the process that follows the caller"""
|
|
|
+ rank_in_group = self.rank_in_group
|
|
|
+ world_size = self.world_size
|
|
|
+ return self.ranks[(rank_in_group + 1) % world_size]
|
|
|
+
|
|
|
+ @property
|
|
|
+ def prev_rank(self):
|
|
|
+ """Return the global rank of the process that precedes the caller"""
|
|
|
+ rank_in_group = self.rank_in_group
|
|
|
+ world_size = self.world_size
|
|
|
+ return self.ranks[(rank_in_group - 1) % world_size]
|
|
|
+
|
|
|
+ @contextmanager
|
|
|
+ def graph_capture(
|
|
|
+ self, graph_capture_context: Optional[GraphCaptureContext] = None):
|
|
|
+ if graph_capture_context is None:
|
|
|
+ stream = torch.cuda.Stream()
|
|
|
+ graph_capture_context = GraphCaptureContext(stream)
|
|
|
+ else:
|
|
|
+ stream = graph_capture_context.stream
|
|
|
+
|
|
|
+ ca_comm = self.ca_comm
|
|
|
+ maybe_ca_context = nullcontext(
|
|
|
+ ) if ca_comm is None else ca_comm.capture()
|
|
|
+
|
|
|
+ # ensure all initialization operations complete before attempting to
|
|
|
+ # capture the graph on another stream
|
|
|
+ curr_stream = torch.cuda.current_stream()
|
|
|
+ if curr_stream != stream:
|
|
|
+ stream.wait_stream(curr_stream)
|
|
|
+
|
|
|
+ with torch.cuda.stream(stream), maybe_ca_context:
|
|
|
+ # In graph mode, we have to be very careful about the collective
|
|
|
+ # operations. The current status is:
|
|
|
+ # allreduce \ Mode | Eager | Graph |
|
|
|
+ # --------------------------------------------
|
|
|
+ # custom allreduce | enabled | enabled |
|
|
|
+ # PyNccl | disabled| enabled |
|
|
|
+ # torch.distributed | enabled | disabled|
|
|
|
+ #
|
|
|
+ # Note that custom allreduce will have a runtime check, if the
|
|
|
+ # tensor size is too large, it will fallback to the next
|
|
|
+ # available option.
|
|
|
+ # In summary: When using CUDA graph, we use
|
|
|
+ # either custom all-reduce kernel or pynccl. When not using
|
|
|
+ # CUDA graph, we use either custom all-reduce kernel or
|
|
|
+ # PyTorch NCCL. We always prioritize using custom all-reduce
|
|
|
+ # kernel but fall back to PyTorch or pynccl if it is
|
|
|
+ # disabled or not supported.
|
|
|
+ pynccl_comm = self.pynccl_comm
|
|
|
+ maybe_pynccl_context: Any
|
|
|
+ if not pynccl_comm:
|
|
|
+ maybe_pynccl_context = nullcontext()
|
|
|
+ else:
|
|
|
+ maybe_pynccl_context = pynccl_comm.change_state(
|
|
|
+ enable=True, stream=torch.cuda.current_stream())
|
|
|
+ with maybe_pynccl_context:
|
|
|
+ yield graph_capture_context
|
|
|
+
|
|
|
+ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
|
|
+ """
|
|
|
+        NOTE: This operation may be applied in-place or out-of-place.
|
|
|
+ Always assume this function modifies its input, but use the return
|
|
|
+ value as the output.
|
|
|
+ """
|
|
|
+ ca_comm = self.ca_comm
|
|
|
+
|
|
|
+ # Bypass the function if we are using only 1 GPU.
|
|
|
+ if self.world_size == 1:
|
|
|
+ return input_
|
|
|
+
|
|
|
+ # For TPUs, use TPU communicator.
|
|
|
+ tpu_comm = self.tpu_communicator
|
|
|
+ if tpu_comm is not None and not tpu_comm.disabled:
|
|
|
+ return tpu_comm.all_reduce(input_)
|
|
|
+
|
|
|
+ if ca_comm is not None:
|
|
|
+ out = ca_comm.custom_all_reduce(input_)
|
|
|
+ if out is not None:
|
|
|
+ return out
|
|
|
+ pynccl_comm = self.pynccl_comm
|
|
|
+ if (pynccl_comm is not None and not pynccl_comm.disabled):
|
|
|
+ pynccl_comm.all_reduce(input_)
|
|
|
+ elif input_.is_cpu:
|
|
|
+ import intel_extension_for_pytorch as ipex
|
|
|
+ ipex.distributed.all_reduce(input_, group=self.device_group)
|
|
|
+ else:
|
|
|
+ torch.distributed.all_reduce(input_, group=self.device_group)
|
|
|
+ return input_
|
|
|
+
|
|
|
+ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
|
|
+ world_size = self.world_size
|
|
|
+ # Bypass the function if we are using only 1 GPU.
|
|
|
+ if world_size == 1:
|
|
|
+ return input_
|
|
|
+ assert -input_.dim() <= dim < input_.dim(), (
|
|
|
+ f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
|
|
+
|
|
|
+ # For TPUs, use TPU communicator.
|
|
|
+ tpu_comm = self.tpu_communicator
|
|
|
+ if tpu_comm is not None and not tpu_comm.disabled:
|
|
|
+ return tpu_comm.all_gather(input_, dim)
|
|
|
+
|
|
|
+ if dim < 0:
|
|
|
+ # Convert negative dim to positive.
|
|
|
+ dim += input_.dim()
|
|
|
+ input_size = input_.size()
|
|
|
+ # Allocate output tensor.
|
|
|
+ output_tensor = torch.empty((world_size, ) + input_size,
|
|
|
+ dtype=input_.dtype,
|
|
|
+ device=input_.device)
|
|
|
+ # All-gather.
|
|
|
+ torch.distributed.all_gather_into_tensor(output_tensor,
|
|
|
+ input_,
|
|
|
+ group=self.device_group)
|
|
|
+ # Reshape
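+        # Shape sketch (a world_size of 2 is assumed): an input of shape
+        # (4, 8) with dim=-1 is gathered into a (2, 4, 8) buffer, movedim
+        # gives (4, 2, 8), and the reshape below returns a (4, 16) tensor.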
|
|
|
+ output_tensor = output_tensor.movedim(0, dim)
|
|
|
+ output_tensor = output_tensor.reshape(input_size[:dim] +
|
|
|
+ (world_size *
|
|
|
+ input_size[dim], ) +
|
|
|
+ input_size[dim + 1:])
|
|
|
+ return output_tensor
|
|
|
+
|
|
|
+ def gather(self,
|
|
|
+ input_: torch.Tensor,
|
|
|
+ dst: int = 0,
|
|
|
+ dim: int = -1) -> torch.Tensor:
|
|
|
+ """
|
|
|
+ NOTE: We assume that the input tensor is on the same device across
|
|
|
+ all the ranks.
|
|
|
+ NOTE: `dst` is the local rank of the destination rank.
|
|
|
+ """
|
|
|
+ world_size = self.world_size
|
|
|
+ # Bypass the function if we are using only 1 GPU.
|
|
|
+ if world_size == 1:
|
|
|
+ return input_
|
|
|
+ assert -input_.dim() <= dim < input_.dim(), (
|
|
|
+ f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
|
|
+ if dim < 0:
|
|
|
+ # Convert negative dim to positive.
|
|
|
+ dim += input_.dim()
|
|
|
+ # Allocate output tensor.
|
|
|
+ if self.rank_in_group == dst:
|
|
|
+ gather_list = [torch.empty_like(input_) for _ in range(world_size)]
|
|
|
+ else:
|
|
|
+ gather_list = None
|
|
|
+ # Gather.
|
|
|
+ torch.distributed.gather(input_,
|
|
|
+ gather_list,
|
|
|
+ dst=self.ranks[dst],
|
|
|
+ group=self.device_group)
|
|
|
+ if self.rank_in_group == dst:
|
|
|
+ output_tensor = torch.cat(gather_list, dim=dim)
|
|
|
+ else:
|
|
|
+ output_tensor = None
|
|
|
+ return output_tensor
|
|
|
+
|
|
|
+ def broadcast(self, input_: torch.Tensor, src: int = 0):
|
|
|
+ """Broadcast the input tensor.
|
|
|
+ NOTE: `src` is the local rank of the source rank.
|
|
|
+ """
|
|
|
+ assert src < self.world_size, f"Invalid src rank ({src})"
|
|
|
+
|
|
|
+ # Bypass the function if we are using only 1 GPU.
|
|
|
+ if self.world_size == 1:
|
|
|
+ return input_
|
|
|
+ # Broadcast.
|
|
|
+ torch.distributed.broadcast(input_,
|
|
|
+ src=self.ranks[src],
|
|
|
+ group=self.device_group)
|
|
|
+ return input_
|
|
|
+
|
|
|
+ def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
|
|
|
+ """Broadcast the input object.
|
|
|
+ NOTE: `src` is the local rank of the source rank.
|
|
|
+ """
|
|
|
+ assert src < self.world_size, f"Invalid src rank ({src})"
|
|
|
+
|
|
|
+ # Bypass the function if we are using only 1 GPU.
|
|
|
+ if self.world_size == 1:
|
|
|
+ return obj
|
|
|
+ if self.mq_broadcaster is not None:
|
|
|
+ assert src == 0, "Message queue broadcaster only supports src=0"
|
|
|
+ return self.mq_broadcaster.broadcast_object(obj)
|
|
|
+ if self.rank_in_group == src:
|
|
|
+ torch.distributed.broadcast_object_list([obj],
|
|
|
+ src=self.ranks[src],
|
|
|
+ group=self.cpu_group)
|
|
|
+ return obj
|
|
|
+ else:
|
|
|
+ recv = [None]
|
|
|
+ torch.distributed.broadcast_object_list(recv,
|
|
|
+ src=self.ranks[src],
|
|
|
+ group=self.cpu_group)
|
|
|
+ return recv[0]
|
|
|
+
|
|
|
+ def broadcast_object_list(self,
|
|
|
+ obj_list: List[Any],
|
|
|
+ src: int = 0,
|
|
|
+ group: Optional[ProcessGroup] = None):
|
|
|
+ """Broadcast the input object list.
|
|
|
+ NOTE: `src` is the local rank of the source rank.
|
|
|
+ """
|
|
|
+ assert src < self.world_size, f"Invalid src rank ({src})"
|
|
|
+
|
|
|
+ # Bypass the function if we are using only 1 GPU.
|
|
|
+ if self.world_size == 1:
|
|
|
+ return obj_list
|
|
|
+ # Broadcast.
|
|
|
+ torch.distributed.broadcast_object_list(obj_list,
|
|
|
+ src=self.ranks[src],
|
|
|
+ group=self.device_group)
|
|
|
+ return obj_list
|
|
|
+
|
|
|
+ def send_object(self, obj: Any, dst: int) -> None:
|
|
|
+ """Send the input object list to the destination rank."""
|
|
|
+ """NOTE: `dst` is the local rank of the destination rank."""
|
|
|
+
|
|
|
+ assert dst < self.world_size, f"Invalid dst rank ({dst})"
|
|
|
+
|
|
|
+ assert dst != self.rank_in_group, (
|
|
|
+ "Invalid destination rank. Destination rank is the same "
|
|
|
+ "as the current rank.")
|
|
|
+
|
|
|
+ # Serialize object to tensor and get the size as well
|
|
|
+ object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
|
|
|
+
|
|
|
+ size_tensor = torch.tensor([object_tensor.numel()],
|
|
|
+ dtype=torch.long,
|
|
|
+ device="cpu")
|
|
|
+
|
|
|
+ # Send object size
|
|
|
+
|
|
|
+ torch.distributed.send(size_tensor,
|
|
|
+ dst=self.ranks[dst],
|
|
|
+ group=self.cpu_group)
|
|
|
+
|
|
|
+ # Send object
|
|
|
+ torch.distributed.send(object_tensor,
|
|
|
+ dst=self.ranks[dst],
|
|
|
+ group=self.cpu_group)
|
|
|
+
|
|
|
+ return None
|
|
|
+
|
|
|
+ def recv_object(self, src: int) -> Any:
|
|
|
+ """Receive the input object list from the source rank."""
|
|
|
+ """NOTE: `src` is the local rank of the source rank."""
|
|
|
+
|
|
|
+ assert src < self.world_size, f"Invalid src rank ({src})"
|
|
|
+
|
|
|
+ assert src != self.rank_in_group, (
|
|
|
+ "Invalid source rank. Source rank is the same as the current rank."
|
|
|
+ )
|
|
|
+
|
|
|
+ size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
|
|
|
+
|
|
|
+ # Receive object size
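+        # Protocol note: this mirrors `send_object`, which sends a 1-element
+        # int64 size tensor followed by the pickled payload, both over the
+        # gloo `cpu_group`, so an exactly-sized buffer can be allocated next.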
|
|
|
+ rank_size = torch.distributed.recv(size_tensor,
|
|
|
+ src=self.ranks[src],
|
|
|
+ group=self.cpu_group)
|
|
|
+
|
|
|
+ # Tensor to receive serialized objects into.
|
|
|
+ object_tensor = torch.empty( # type: ignore[call-overload]
|
|
|
+ size_tensor.item(), # type: ignore[arg-type]
|
|
|
+ dtype=torch.uint8,
|
|
|
+ device="cpu")
|
|
|
+
|
|
|
+ rank_object = torch.distributed.recv(object_tensor,
|
|
|
+ src=self.ranks[src],
|
|
|
+ group=self.cpu_group)
|
|
|
+
|
|
|
+ assert rank_object == rank_size, (
|
|
|
+ "Received object sender rank does not match the size sender rank.")
|
|
|
+
|
|
|
+ obj = pickle.loads(object_tensor.numpy().tobytes())
|
|
|
+
|
|
|
+ return obj
|
|
|
+
|
|
|
+ def broadcast_tensor_dict(
|
|
|
+ self,
|
|
|
+ tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
|
|
|
+ src: int = 0,
|
|
|
+ group: Optional[ProcessGroup] = None,
|
|
|
+ metadata_group: Optional[ProcessGroup] = None
|
|
|
+ ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
|
|
|
+ """Broadcast the input tensor dictionary.
|
|
|
+ NOTE: `src` is the local rank of the source rank.
|
|
|
+ """
|
|
|
+ # Bypass the function if we are using only 1 GPU.
|
|
|
+ if (not torch.distributed.is_initialized() or self.world_size == 1):
|
|
|
+ return tensor_dict
|
|
|
+
|
|
|
+ group = self.device_group
|
|
|
+ metadata_group = self.cpu_group
|
|
|
+ assert src < self.world_size, f"Invalid src rank ({src})"
|
|
|
|
|
|
-# Tensor model parallel group that the current rank belongs to.
|
|
|
-_TENSOR_MODEL_PARALLEL_GROUP = None
|
|
|
-# Pipeline model parallel group that the current rank belongs to.
|
|
|
-_PIPELINE_MODEL_PARALLEL_GROUP = None
|
|
|
+ rank_in_group = self.rank_in_group
|
|
|
+ if rank_in_group == src:
|
|
|
+ metadata_list: List[Tuple[Any, Any]] = []
|
|
|
+ assert isinstance(
|
|
|
+ tensor_dict,
|
|
|
+ dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
|
|
|
+ metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
|
|
|
+ # `metadata_list` lives in CPU memory.
|
|
|
+ # `broadcast_object_list` has serialization & deserialization,
|
|
|
+ # all happening on CPU. Therefore, we can use the CPU group.
|
|
|
+ self.broadcast_object(metadata_list, src=src)
|
|
|
+ async_handles = []
|
|
|
+ for tensor in tensor_list:
|
|
|
+ if tensor.numel() == 0:
|
|
|
+ # Skip broadcasting empty tensors.
|
|
|
+ continue
|
|
|
+ if tensor.is_cpu:
|
|
|
+ # use metadata_group for CPU tensors
|
|
|
+ handle = torch.distributed.broadcast(tensor,
|
|
|
+ src=self.ranks[src],
|
|
|
+ group=metadata_group,
|
|
|
+ async_op=True)
|
|
|
+ else:
|
|
|
+ # use group for GPU tensors
|
|
|
+ handle = torch.distributed.broadcast(tensor,
|
|
|
+ src=self.ranks[src],
|
|
|
+ group=group,
|
|
|
+ async_op=True)
|
|
|
+ async_handles.append(handle)
|
|
|
+ for async_handle in async_handles:
|
|
|
+ async_handle.wait()
|
|
|
|
|
|
-# when people blindly call `torch.distributed.all_reduce` etc,
|
|
|
-# it will use this group. It is initialized with the `backend`
|
|
|
-# parameter of `init_distributed_environment` below.
|
|
|
-# Essentially, this is `torch.distributed.group.WORLD`.
|
|
|
-# We leave a line here to note that this is device-specific.
|
|
|
-# Note that this variable is not safe to use, because when users
|
|
|
-# call `init_distributed_environment` first, and then destroy
|
|
|
-# the process group themselves, this variable will keep a reference to the
|
|
|
-# destroyed process group, which is not useful.
|
|
|
-_DEVICE_WORLD_GROUP = None
|
|
|
+ else:
|
|
|
+ metadata_list = self.broadcast_object(None, src=src)
|
|
|
+ tensor_dict = {}
|
|
|
+ async_handles = []
|
|
|
+ for key, value in metadata_list:
|
|
|
+ if isinstance(value, TensorMetadata):
|
|
|
+ tensor = torch.empty(value.size,
|
|
|
+ dtype=value.dtype,
|
|
|
+ device=value.device)
|
|
|
+ if tensor.numel() == 0:
|
|
|
+ # Skip broadcasting empty tensors.
|
|
|
+ tensor_dict[key] = tensor
|
|
|
+ continue
|
|
|
+ if tensor.is_cpu:
|
|
|
+ # use metadata_group for CPU tensors
|
|
|
+ handle = torch.distributed.broadcast(
|
|
|
+ tensor,
|
|
|
+ src=self.ranks[src],
|
|
|
+ group=metadata_group,
|
|
|
+ async_op=True)
|
|
|
+ else:
|
|
|
+ # use group for GPU tensors
|
|
|
+ handle = torch.distributed.broadcast(
|
|
|
+ tensor,
|
|
|
+ src=self.ranks[src],
|
|
|
+ group=group,
|
|
|
+ async_op=True)
|
|
|
+ async_handles.append(handle)
|
|
|
+ tensor_dict[key] = tensor
|
|
|
+ else:
|
|
|
+ tensor_dict[key] = value
|
|
|
+ for async_handle in async_handles:
|
|
|
+ async_handle.wait()
|
|
|
+ return tensor_dict
|
|
|
|
|
|
-# duing `init_distributed_environment`, we will also initialize a
|
|
|
-# group with `gloo` backend, to allow direct coordination between
|
|
|
-# processes through the CPU.
|
|
|
-_CPU_WORLD_GROUP = None
|
|
|
+ def send_tensor_dict(
|
|
|
+ self,
|
|
|
+ tensor_dict: Dict[str, Union[torch.Tensor, Any]],
|
|
|
+ dst: Optional[int] = None,
|
|
|
+ all_gather_group: Optional["GroupCoordinator"] = None,
|
|
|
+ ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
|
|
|
+ """Send the input tensor dictionary.
|
|
|
+        NOTE: `dst` is the local rank of the destination rank.
|
|
|
+ """
|
|
|
+ # Bypass the function if we are using only 1 GPU.
|
|
|
+ if not torch.distributed.is_initialized() or self.world_size == 1:
|
|
|
+ return tensor_dict
|
|
|
|
|
|
-# In summary, after calling `init_distributed_environment`, we will
|
|
|
-# always have two groups: one for device-specific (and is the default)
|
|
|
-# and one for CPU. All processes will be part of both groups.
|
|
|
+ all_gather_size = (1 if all_gather_group is None else
|
|
|
+ all_gather_group.world_size)
|
|
|
+ all_gather_rank = (0 if all_gather_group is None else
|
|
|
+ all_gather_group.rank_in_group)
|
|
|
|
|
|
-# A list of global ranks for each pipeline group to ease calculation of the
|
|
|
-# source rank when broadcasting from the first or last pipeline stage.
|
|
|
-_PIPELINE_GLOBAL_RANKS = None
|
|
|
+ group = self.device_group
|
|
|
+ metadata_group = self.cpu_group
|
|
|
|
|
|
-_LOCAL_RANK = -1
|
|
|
+ if dst is None:
|
|
|
+ dst = (self.rank_in_group + 1) % self.world_size
|
|
|
+ assert dst < self.world_size, f"Invalid dst rank ({dst})"
|
|
|
|
|
|
+ metadata_list: List[Tuple[Any, Any]] = []
|
|
|
+ assert isinstance(
|
|
|
+ tensor_dict,
|
|
|
+ dict), f"Expecting a dictionary, got {type(tensor_dict)}"
|
|
|
+ metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
|
|
|
+ # `metadata_list` lives in CPU memory.
|
|
|
+ # `send_object_list` has serialization & deserialization,
|
|
|
+ # all happening on CPU. Therefore, we can use the CPU group.
|
|
|
+ self.send_object(metadata_list, dst=dst)
|
|
|
+ for tensor in tensor_list:
|
|
|
+ if tensor.numel() == 0:
|
|
|
+ # Skip sending empty tensors.
|
|
|
+ continue
|
|
|
|
|
|
-def get_local_rank():
|
|
|
- global _LOCAL_RANK
|
|
|
- return _LOCAL_RANK
|
|
|
+ # send-allgather: send only a slice, then do allgather.
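+            # Sketch (sizes are assumed): with an all-gather group of size 4
+            # and a 1024-element tensor, only the 256-element slice owned by
+            # this rank is sent; the receiver rebuilds the full tensor with
+            # `all_gather_group.all_gather` in `recv_tensor_dict`.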
|
|
|
+ if (all_gather_group is not None
|
|
|
+ and tensor.numel() % all_gather_size == 0):
|
|
|
+ tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
|
|
|
+
|
|
|
+ if tensor.is_cpu:
|
|
|
+ # use metadata_group for CPU tensors
|
|
|
+ torch.distributed.send(tensor,
|
|
|
+ dst=self.ranks[dst],
|
|
|
+ group=metadata_group)
|
|
|
+ else:
|
|
|
+ # use group for GPU tensors
|
|
|
+ torch.distributed.send(tensor,
|
|
|
+ dst=self.ranks[dst],
|
|
|
+ group=group)
|
|
|
+ return None
|
|
|
+
|
|
|
+ def recv_tensor_dict(
|
|
|
+ self,
|
|
|
+ src: Optional[int] = None,
|
|
|
+ all_gather_group: Optional["GroupCoordinator"] = None,
|
|
|
+ ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
|
|
|
+ """Recv the input tensor dictionary.
|
|
|
+ NOTE: `src` is the local rank of the source rank.
|
|
|
+ """
|
|
|
+ # Bypass the function if we are using only 1 GPU.
|
|
|
+ if not torch.distributed.is_initialized() or self.world_size == 1:
|
|
|
+ return None
|
|
|
+
|
|
|
+ all_gather_size = (1 if all_gather_group is None else
|
|
|
+ all_gather_group.world_size)
|
|
|
+ all_gather_rank = (0 if all_gather_group is None else
|
|
|
+ all_gather_group.rank_in_group)
|
|
|
+
|
|
|
+ group = self.device_group
|
|
|
+ metadata_group = self.cpu_group
|
|
|
+
|
|
|
+ if src is None:
|
|
|
+ src = (self.rank_in_group - 1) % self.world_size
|
|
|
+ assert src < self.world_size, f"Invalid src rank ({src})"
|
|
|
+
|
|
|
+ recv_metadata_list = self.recv_object(src=src)
|
|
|
+ tensor_dict: Dict[str, Any] = {}
|
|
|
+ for key, value in recv_metadata_list:
|
|
|
+ if isinstance(value, TensorMetadata):
|
|
|
+ tensor = torch.empty(value.size,
|
|
|
+ dtype=value.dtype,
|
|
|
+ device=value.device)
|
|
|
+ if tensor.numel() == 0:
|
|
|
+ # Skip broadcasting empty tensors.
|
|
|
+ tensor_dict[key] = tensor
|
|
|
+ continue
|
|
|
+
|
|
|
+                # recv-allgather: receive only a slice, then do allgather.
|
|
|
+ use_all_gather = (all_gather_group is not None
|
|
|
+ and tensor.numel() % all_gather_size == 0)
|
|
|
+
|
|
|
+ if use_all_gather:
|
|
|
+ orig_shape = tensor.shape
|
|
|
+ tensor = tensor.reshape(all_gather_size,
|
|
|
+ -1)[all_gather_rank]
|
|
|
+
|
|
|
+ if tensor.is_cpu:
|
|
|
+ # use metadata_group for CPU tensors
|
|
|
+ torch.distributed.recv(tensor,
|
|
|
+ src=self.ranks[src],
|
|
|
+ group=metadata_group)
|
|
|
+ else:
|
|
|
+ # use group for GPU tensors
|
|
|
+ torch.distributed.recv(tensor,
|
|
|
+ src=self.ranks[src],
|
|
|
+ group=group)
|
|
|
+ if use_all_gather:
|
|
|
+ # do the allgather
|
|
|
+ tensor = all_gather_group.all_gather( # type: ignore
|
|
|
+ tensor, dim=0)
|
|
|
+ tensor = tensor.reshape(orig_shape)
|
|
|
+
|
|
|
+ tensor_dict[key] = tensor
|
|
|
+ else:
|
|
|
+ tensor_dict[key] = value
|
|
|
+ return tensor_dict
|
|
|
+
|
|
|
+ def barrier(self):
|
|
|
+ """Barrier synchronization among the group.
|
|
|
+ NOTE: don't use `device_group` here! `barrier` in NCCL is
|
|
|
+ terrible because it is internally a broadcast operation with
|
|
|
+ secretly created GPU tensors. It is easy to mess up the current
|
|
|
+ device. Use the CPU group instead.
|
|
|
+ """
|
|
|
+ torch.distributed.barrier(group=self.cpu_group)
|
|
|
+
|
|
|
+ def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
|
|
|
+ """Sends a tensor to the destination rank in a non-blocking way"""
|
|
|
+ """NOTE: `dst` is the local rank of the destination rank."""
|
|
|
+ if dst is None:
|
|
|
+ dst = (self.rank_in_group + 1) % self.world_size
|
|
|
+
|
|
|
+ pynccl_comm = self.pynccl_comm
|
|
|
+ if pynccl_comm is not None and not pynccl_comm.disabled:
|
|
|
+ pynccl_comm.send(tensor, dst)
|
|
|
+ else:
|
|
|
+ torch.distributed.send(tensor, self.ranks[dst], self.device_group)
|
|
|
+
|
|
|
+ def recv(self,
|
|
|
+ size: torch.Size,
|
|
|
+ dtype: torch.dtype,
|
|
|
+ src: Optional[int] = None) -> torch.Tensor:
|
|
|
+ """Receives a tensor from the src rank."""
|
|
|
+ """NOTE: `src` is the local rank of the destination rank."""
|
|
|
+ if src is None:
|
|
|
+ src = (self.rank_in_group - 1) % self.world_size
|
|
|
+
|
|
|
+ tensor = torch.empty(size, dtype=dtype, device=self.device)
|
|
|
+ pynccl_comm = self.pynccl_comm
|
|
|
+ if pynccl_comm is not None and not pynccl_comm.disabled:
|
|
|
+ pynccl_comm.recv(tensor, src)
|
|
|
+ else:
|
|
|
+ torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
|
|
+ return tensor
|
|
|
+
|
|
|
+ def destroy(self):
|
|
|
+ if self.device_group is not None:
|
|
|
+ torch.distributed.destroy_process_group(self.device_group)
|
|
|
+ self.device_group = None
|
|
|
+ if self.cpu_group is not None:
|
|
|
+ torch.distributed.destroy_process_group(self.cpu_group)
|
|
|
+ self.cpu_group = None
|
|
|
+ if self.pynccl_comm is not None:
|
|
|
+ self.pynccl_comm = None
|
|
|
+ if self.ca_comm is not None:
|
|
|
+ self.ca_comm = None
|
|
|
+ if self.mq_broadcaster is not None:
|
|
|
+ self.mq_broadcaster = None
|
|
|
+
|
|
|
+
|
|
|
+_WORLD: Optional[GroupCoordinator] = None
|
|
|
+
|
|
|
+
|
|
|
+def get_world_group() -> GroupCoordinator:
|
|
|
+ assert _WORLD is not None, ("world group is not initialized")
|
|
|
+ return _WORLD
|
|
|
+
|
|
|
+
|
|
|
+def init_world_group(ranks: List[int], local_rank: int,
|
|
|
+ backend: str) -> GroupCoordinator:
|
|
|
+ return GroupCoordinator(
|
|
|
+ group_ranks=[ranks],
|
|
|
+ local_rank=local_rank,
|
|
|
+ torch_distributed_backend=backend,
|
|
|
+ use_pynccl=False,
|
|
|
+ use_custom_allreduce=False,
|
|
|
+ use_tpu_communicator=False,
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+def init_model_parallel_group(
|
|
|
+ group_ranks: List[List[int]],
|
|
|
+ local_rank: int,
|
|
|
+ backend: str,
|
|
|
+ use_custom_allreduce: Optional[bool] = None,
|
|
|
+ use_message_queue_broadcaster: bool = False,
|
|
|
+) -> GroupCoordinator:
|
|
|
+ if use_custom_allreduce is None:
|
|
|
+ use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
|
|
|
+ return GroupCoordinator(
|
|
|
+ group_ranks=group_ranks,
|
|
|
+ local_rank=local_rank,
|
|
|
+ torch_distributed_backend=backend,
|
|
|
+ use_pynccl=True,
|
|
|
+ use_custom_allreduce=use_custom_allreduce,
|
|
|
+ use_tpu_communicator=True,
|
|
|
+ use_message_queue_broadcaster=use_message_queue_broadcaster,
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+_TP: Optional[GroupCoordinator] = None
|
|
|
+
|
|
|
+
|
|
|
+def get_tp_group() -> GroupCoordinator:
|
|
|
+ assert _TP is not None, ("tensor model parallel group is not initialized")
|
|
|
+ return _TP
|
|
|
+
|
|
|
+
|
|
|
+# kept for backward compatibility
|
|
|
+get_tensor_model_parallel_group = get_tp_group
|
|
|
+
|
|
|
+_PP: Optional[GroupCoordinator] = None
|
|
|
+
|
|
|
+
|
|
|
+def get_pp_group() -> GroupCoordinator:
|
|
|
+ assert _PP is not None, (
|
|
|
+ "pipeline model parallel group is not initialized")
|
|
|
+ return _PP
|
|
|
+
|
|
|
+
|
|
|
+# kept for backward compatibility
|
|
|
+get_pipeline_model_parallel_group = get_pp_group
|
|
|
+
|
|
|
+
|
|
|
+@contextmanager
|
|
|
+def graph_capture():
|
|
|
+ """
|
|
|
+ `graph_capture` is a context manager which should surround the code that
|
|
|
+    is capturing the CUDA graph. Its main purpose is to ensure that
|
|
|
+ some operations will be run after the graph is captured, before the graph
|
|
|
+ is replayed. It returns a `GraphCaptureContext` object which contains the
|
|
|
+ necessary data for the graph capture. Currently, it only contains the
|
|
|
+ stream that the graph capture is running on. This stream is set to the
|
|
|
+ current CUDA stream when the context manager is entered and reset to the
|
|
|
+ default stream when the context manager is exited. This is to ensure that
|
|
|
+ the graph capture is running on a separate stream from the default stream,
|
|
|
+ in order to explicitly distinguish the kernels to capture
|
|
|
+    from other kernels that may run in the background on the default stream.
|
|
|
+ """
|
|
|
+ with get_tp_group().graph_capture() as context, get_pp_group(
|
|
|
+ ).graph_capture(context):
|
|
|
+ yield context
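+
+
+# Usage sketch (the capture code itself is assumed, not defined here):
+#
+#     graph = torch.cuda.CUDAGraph()
+#     with graph_capture() as ctx:
+#         with torch.cuda.graph(graph, stream=ctx.stream):
+#             ...  # kernels to capture and later replay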
|
|
|
+
|
|
|
+
|
|
|
+_ENABLE_CUSTOM_ALL_REDUCE = True
|
|
|
+
|
|
|
+
|
|
|
+def set_custom_all_reduce(enable: bool):
|
|
|
+ global _ENABLE_CUSTOM_ALL_REDUCE
|
|
|
+ _ENABLE_CUSTOM_ALL_REDUCE = enable
|
|
|
|
|
|
|
|
|
def init_distributed_environment(
|
|
@@ -54,8 +840,9 @@ def init_distributed_environment(
|
|
|
local_rank: int = -1,
|
|
|
backend: str = "nccl",
|
|
|
):
|
|
|
- logger.debug(f"{world_size=} {rank=} {local_rank=} "
|
|
|
- f"{distributed_init_method=} {backend=}")
|
|
|
+ logger.debug(
|
|
|
+ f"world_size={world_size} rank={rank} local_rank={local_rank} "
|
|
|
+ f"distributed_init_method={distributed_init_method} backend={backend}")
|
|
|
if not torch.distributed.is_initialized():
|
|
|
assert distributed_init_method is not None, (
|
|
|
"distributed_init_method must be provided when initializing "
|
|
@@ -66,13 +853,23 @@ def init_distributed_environment(
|
|
|
init_method=distributed_init_method,
|
|
|
world_size=world_size,
|
|
|
rank=rank)
|
|
|
- global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP
|
|
|
- _DEVICE_WORLD_GROUP = torch.distributed.group.WORLD
|
|
|
+ # set the local rank
|
|
|
+ # local_rank is not available in torch ProcessGroup,
|
|
|
+ # see https://github.com/pytorch/pytorch/issues/122816
|
|
|
+ if local_rank == -1:
|
|
|
+ # local rank not set, this usually happens in single-node
|
|
|
+ # setting, where we can use rank as local rank
|
|
|
+ if distributed_init_method == "env://":
|
|
|
+            local_rank = int(os.getenv("LOCAL_RANK", rank))
|
|
|
+ else:
|
|
|
+ local_rank = rank
|
|
|
+ global _WORLD
|
|
|
+ if _WORLD is None:
|
|
|
ranks = list(range(torch.distributed.get_world_size()))
|
|
|
- _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
|
|
|
- backend="gloo")
|
|
|
- global _LOCAL_RANK
|
|
|
- _LOCAL_RANK = local_rank
|
|
|
+ _WORLD = init_world_group(ranks, local_rank, backend)
|
|
|
+ else:
|
|
|
+ assert _WORLD.world_size == torch.distributed.get_world_size(), (
|
|
|
+ "world group already initialized with a different world size")
|
|
|
|
|
|
|
|
|
def initialize_model_parallel(
|
|
@@ -105,8 +902,8 @@ def initialize_model_parallel(
|
|
|
# Get world size and rank. Ensure some consistencies.
|
|
|
assert torch.distributed.is_initialized()
|
|
|
world_size: int = torch.distributed.get_world_size()
|
|
|
- # get the backend of _DEVICE_WORLD_GROUP
|
|
|
- backend = backend or torch.distributed.get_backend()
|
|
|
+ backend = backend or torch.distributed.get_backend(
|
|
|
+ get_world_group().device_group)
|
|
|
|
|
|
if (world_size !=
|
|
|
tensor_model_parallel_size * pipeline_model_parallel_size):
|
|
@@ -115,34 +912,39 @@ def initialize_model_parallel(
|
|
|
f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
|
|
|
f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
|
|
|
|
|
|
+ # Build the tensor model-parallel groups.
|
|
|
num_tensor_model_parallel_groups: int = (world_size //
|
|
|
tensor_model_parallel_size)
|
|
|
- num_pipeline_model_parallel_groups: int = (world_size //
|
|
|
- pipeline_model_parallel_size)
|
|
|
- rank = torch.distributed.get_rank()
|
|
|
-
|
|
|
- # Build the tensor model-parallel groups.
|
|
|
- global _TENSOR_MODEL_PARALLEL_GROUP
|
|
|
- assert _TENSOR_MODEL_PARALLEL_GROUP is None, (
|
|
|
- "tensor model parallel group is already initialized")
|
|
|
+ global _TP
|
|
|
+ assert _TP is None, ("tensor model parallel group is already initialized")
|
|
|
+ group_ranks = []
|
|
|
for i in range(num_tensor_model_parallel_groups):
|
|
|
- ranks = range(i * tensor_model_parallel_size,
|
|
|
- (i + 1) * tensor_model_parallel_size)
|
|
|
- group = torch.distributed.new_group(ranks, backend=backend)
|
|
|
- if rank in ranks:
|
|
|
- _TENSOR_MODEL_PARALLEL_GROUP = group
|
|
|
+ ranks = list(
|
|
|
+ range(i * tensor_model_parallel_size,
|
|
|
+ (i + 1) * tensor_model_parallel_size))
|
|
|
+ group_ranks.append(ranks)
|
|
|
+
|
|
|
+ # message queue broadcaster is only used in tensor model parallel group
|
|
|
+ _TP = init_model_parallel_group(group_ranks,
|
|
|
+ get_world_group().local_rank,
|
|
|
+ backend,
|
|
|
+ use_message_queue_broadcaster=True)
|
|
|
|
|
|
# Build the pipeline model-parallel groups.
|
|
|
- global _PIPELINE_MODEL_PARALLEL_GROUP
|
|
|
- global _PIPELINE_GLOBAL_RANKS
|
|
|
- assert _PIPELINE_MODEL_PARALLEL_GROUP is None, (
|
|
|
+ num_pipeline_model_parallel_groups: int = (world_size //
|
|
|
+ pipeline_model_parallel_size)
|
|
|
+ global _PP
|
|
|
+ assert _PP is None, (
|
|
|
"pipeline model parallel group is already initialized")
|
|
|
+ group_ranks = []
|
|
|
for i in range(num_pipeline_model_parallel_groups):
|
|
|
- ranks = range(i, world_size, num_pipeline_model_parallel_groups)
|
|
|
- group = torch.distributed.new_group(ranks, backend=backend)
|
|
|
- if rank in ranks:
|
|
|
- _PIPELINE_MODEL_PARALLEL_GROUP = group
|
|
|
- _PIPELINE_GLOBAL_RANKS = ranks
|
|
|
+ ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
|
|
|
+ group_ranks.append(ranks)
|
|
|
+ # pipeline parallel does not need custom allreduce
|
|
|
+ _PP = init_model_parallel_group(group_ranks,
|
|
|
+ get_world_group().local_rank,
|
|
|
+ backend,
|
|
|
+ use_custom_allreduce=False)
|
|
|
|
|
|
|
|
|
def ensure_model_parallel_initialized(
|
|
@@ -154,8 +956,8 @@ def ensure_model_parallel_initialized(
|
|
|
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
|
|
|
values if the model parallel groups are initialized.
|
|
|
"""
|
|
|
- # get the backend of _DEVICE_WORLD_GROUP
|
|
|
- backend = backend or torch.distributed.get_backend()
|
|
|
+ backend = backend or torch.distributed.get_backend(
|
|
|
+ get_world_group().device_group)
|
|
|
if not model_parallel_is_initialized():
|
|
|
initialize_model_parallel(tensor_model_parallel_size,
|
|
|
pipeline_model_parallel_size, backend)
|
|
@@ -166,149 +968,167 @@ def ensure_model_parallel_initialized(
|
|
|
), ("tensor parallel group already initialized, but of unexpected size: "
|
|
|
f"{get_tensor_model_parallel_world_size()=} vs. "
|
|
|
f"{tensor_model_parallel_size=}")
|
|
|
- assert (get_pipeline_model_parallel_world_size(
|
|
|
- ) == pipeline_model_parallel_size), (
|
|
|
+ pp_world_size = get_pp_group().world_size
|
|
|
+ assert (pp_world_size == pipeline_model_parallel_size), (
|
|
|
"pipeline parallel group already initialized, but of unexpected size: "
|
|
|
- f"{get_pipeline_model_parallel_world_size()=} vs. "
|
|
|
+ f"{pp_world_size=} vs. "
|
|
|
f"{pipeline_model_parallel_size=}")
|
|
|
|
|
|
|
|
|
def model_parallel_is_initialized():
|
|
|
"""Check if tensor and pipeline parallel groups are initialized."""
|
|
|
- return (_TENSOR_MODEL_PARALLEL_GROUP is not None
|
|
|
- and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
|
|
|
-
|
|
|
+ return (_TP is not None and _PP is not None)
|
|
|
|
|
|
-def get_cpu_world_group():
|
|
|
- """Get the CPU world group."""
|
|
|
- assert _CPU_WORLD_GROUP is not None, ("CPU world group is not initialized")
|
|
|
- return _CPU_WORLD_GROUP
|
|
|
|
|
|
+_TP_STATE_PATCHED = False
|
|
|
|
|
|
-def get_tensor_model_parallel_group():
|
|
|
- """Get the tensor model parallel group the caller rank belongs to."""
|
|
|
- assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
|
|
|
- "tenosr model parallel group is not initialized")
|
|
|
- return _TENSOR_MODEL_PARALLEL_GROUP
|
|
|
|
|
|
+@contextmanager
|
|
|
+def patch_tensor_parallel_group(tp_group: GroupCoordinator):
|
|
|
+ """Patch the tp group temporarily until this function ends.
|
|
|
+ This method is for draft workers of speculative decoding to run draft model
|
|
|
+ with different tp degree from that of target model workers.
|
|
|
+ Args:
|
|
|
+ tp_group (GroupCoordinator): the tp group coordinator
|
|
|
+ """
|
|
|
+ global _TP_STATE_PATCHED
|
|
|
+ assert not _TP_STATE_PATCHED, "Should not call when it's already patched"
|
|
|
|
|
|
-def get_pipeline_model_parallel_group():
|
|
|
- """Get the pipeline model parallel group the caller rank belongs to."""
|
|
|
- assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, (
|
|
|
- "pipeline model parallel group is not initialized")
|
|
|
- return _PIPELINE_MODEL_PARALLEL_GROUP
|
|
|
+ _TP_STATE_PATCHED = True
|
|
|
+ old_tp_group = get_tp_group()
|
|
|
+ global _TP
|
|
|
+ _TP = tp_group
|
|
|
+ try:
|
|
|
+ yield
|
|
|
+ finally:
|
|
|
+ # restore the original state
|
|
|
+ _TP_STATE_PATCHED = False
|
|
|
+ _TP = old_tp_group
|
|
|
|
|
|
|
|
|
def get_tensor_model_parallel_world_size():
|
|
|
"""Return world size for the tensor model parallel group."""
|
|
|
- return torch.distributed.get_world_size(
|
|
|
- group=get_tensor_model_parallel_group())
|
|
|
-
|
|
|
-
|
|
|
-def get_pipeline_model_parallel_world_size():
|
|
|
- """Return world size for the pipeline model parallel group."""
|
|
|
- return torch.distributed.get_world_size(
|
|
|
- group=get_pipeline_model_parallel_group())
|
|
|
+ return get_tp_group().world_size
|
|
|
|
|
|
|
|
|
def get_tensor_model_parallel_rank():
|
|
|
"""Return my rank for the tensor model parallel group."""
|
|
|
- return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
|
|
|
+ return get_tp_group().rank_in_group
|
|
|
+
|
|
|
+
|
|
|
+def destroy_model_parallel():
|
|
|
+ """Set the groups to none and destroy them."""
|
|
|
+ global _TP
|
|
|
+ if _TP:
|
|
|
+ _TP.destroy()
|
|
|
+ _TP = None
|
|
|
|
|
|
+ global _PP
|
|
|
+ if _PP:
|
|
|
+ _PP.destroy()
|
|
|
+ _PP = None
|
|
|
|
|
|
-def get_pipeline_model_parallel_rank():
|
|
|
- """Return my rank for the pipeline model parallel group."""
|
|
|
- return torch.distributed.get_rank(
|
|
|
- group=get_pipeline_model_parallel_group())
|
|
|
|
|
|
+def destroy_distributed_environment():
|
|
|
+ global _WORLD
|
|
|
+ if _WORLD:
|
|
|
+ _WORLD.destroy()
|
|
|
+ _WORLD = None
|
|
|
+ if torch.distributed.is_initialized():
|
|
|
+ torch.distributed.destroy_process_group()
|
|
|
|
|
|
-def get_tensor_model_parallel_src_rank():
|
|
|
- """Calculate the global rank corresponding to the first local rank
|
|
|
- in the tensor model parallel group."""
|
|
|
- global_rank = torch.distributed.get_rank()
|
|
|
- local_world_size = get_tensor_model_parallel_world_size()
|
|
|
- return (global_rank // local_world_size) * local_world_size
|
|
|
|
|
|
+def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
|
|
|
+ """
|
|
|
+    This is a collective operation that returns whether each rank is in the
|
|
|
+    same node as the source rank. It tests whether processes are attached to
|
|
|
+    the same memory system (shared access to shared memory).
|
|
|
+ """
|
|
|
+ assert torch.distributed.get_backend(
|
|
|
+ pg) != torch.distributed.Backend.NCCL, (
|
|
|
+ "in_the_same_node_as should be tested with a non-NCCL group.")
|
|
|
+ # local rank inside the group
|
|
|
+ rank = torch.distributed.get_rank(group=pg)
|
|
|
+ world_size = torch.distributed.get_world_size(group=pg)
|
|
|
|
|
|
-def get_pipeline_model_parallel_first_rank():
|
|
|
- """Return the global rank of the first process in the pipeline for the
|
|
|
- current tensor parallel group"""
|
|
|
- assert _PIPELINE_GLOBAL_RANKS is not None, (
|
|
|
- "Pipeline parallel group is not initialized")
|
|
|
- return _PIPELINE_GLOBAL_RANKS[0]
|
|
|
+ # local tensor in each process to store the result
|
|
|
+ is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
|
|
|
|
|
|
+ # global ranks of the processes in the group
|
|
|
+ ranks = torch.distributed.get_process_group_ranks(pg)
|
|
|
|
|
|
-def get_pipeline_model_parallel_last_rank():
|
|
|
- """Return the global rank of the last process in the pipeline for the
|
|
|
- current tensor parallel group"""
|
|
|
- assert _PIPELINE_GLOBAL_RANKS is not None, (
|
|
|
- "Pipeline parallel group is not initialized")
|
|
|
- last_rank_local = get_pipeline_model_parallel_world_size() - 1
|
|
|
- return _PIPELINE_GLOBAL_RANKS[last_rank_local]
|
|
|
+ magic_message = b"magic_message"
|
|
|
+ shm = None
|
|
|
|
|
|
+ try:
|
|
|
+ with contextlib.suppress(OSError):
|
|
|
+ if rank == source_rank:
|
|
|
+ # create a shared memory segment
|
|
|
+ shm = shared_memory.SharedMemory(create=True, size=128)
|
|
|
+ shm.buf[:len(magic_message)] = magic_message
|
|
|
+ torch.distributed.broadcast_object_list([shm.name],
|
|
|
+ src=ranks[source_rank],
|
|
|
+ group=pg)
|
|
|
+ is_in_the_same_node[rank] = 1
|
|
|
+ else:
|
|
|
+ # try to open the shared memory segment
|
|
|
+ recv = [None]
|
|
|
+ torch.distributed.broadcast_object_list(recv,
|
|
|
+ src=ranks[source_rank],
|
|
|
+ group=pg)
|
|
|
+ name = recv[0]
|
|
|
+ # fix to https://stackoverflow.com/q/62748654/9191338
|
|
|
+ # Python incorrectly tracks shared memory even if it is not
|
|
|
+ # created by the process. The following patch is a workaround.
|
|
|
+ with patch("multiprocessing.resource_tracker.register",
|
|
|
+ lambda *args, **kwargs: None):
|
|
|
+ shm = shared_memory.SharedMemory(name=name)
|
|
|
+ if shm.buf[:len(magic_message)] == magic_message:
|
|
|
+ is_in_the_same_node[rank] = 1
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error ignored in is_in_the_same_node: {e}")
|
|
|
+ finally:
|
|
|
+ if shm:
|
|
|
+ shm.close()
|
|
|
|
|
|
-def get_pipeline_model_parallel_next_rank():
|
|
|
- """Return the global rank that follows the caller in the pipeline"""
|
|
|
- assert _PIPELINE_GLOBAL_RANKS is not None, (
|
|
|
- "Pipeline parallel group is not initialized")
|
|
|
- rank_in_pipeline = get_pipeline_model_parallel_rank()
|
|
|
- world_size = get_pipeline_model_parallel_world_size()
|
|
|
- return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
|
|
|
+ torch.distributed.barrier(group=pg)
|
|
|
|
|
|
+ # clean up the shared memory segment
|
|
|
+ with contextlib.suppress(OSError):
|
|
|
+ if rank == source_rank and shm:
|
|
|
+ shm.unlink()
|
|
|
+ torch.distributed.all_reduce(is_in_the_same_node, group=pg)
|
|
|
|
|
|
-def get_pipeline_model_parallel_prev_rank():
|
|
|
- """Return the global rank that precedes the caller in the pipeline"""
|
|
|
- assert _PIPELINE_GLOBAL_RANKS is not None, (
|
|
|
- "Pipeline parallel group is not initialized")
|
|
|
- rank_in_pipeline = get_pipeline_model_parallel_rank()
|
|
|
- world_size = get_pipeline_model_parallel_world_size()
|
|
|
- return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
|
|
|
+ return [x == 1 for x in is_in_the_same_node.tolist()]
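+
+
+# Usage sketch (the gloo-backed world CPU group is assumed as the argument):
+#
+#     same_node = in_the_same_node_as(get_world_group().cpu_group,
+#                                     source_rank=0)
+#     # same_node[i] is True iff rank i shares a node with rank 0.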
|
|
|
|
|
|
|
|
|
-def destroy_model_parallel():
|
|
|
- """Set the groups to none and destroy them."""
|
|
|
- global _TENSOR_MODEL_PARALLEL_GROUP
|
|
|
- if _TENSOR_MODEL_PARALLEL_GROUP:
|
|
|
- torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP)
|
|
|
- _TENSOR_MODEL_PARALLEL_GROUP = None
|
|
|
- global _PIPELINE_MODEL_PARALLEL_GROUP
|
|
|
- if _PIPELINE_MODEL_PARALLEL_GROUP:
|
|
|
- torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
|
|
|
- _PIPELINE_MODEL_PARALLEL_GROUP = None
|
|
|
- global _PIPELINE_GLOBAL_RANKS
|
|
|
- _PIPELINE_GLOBAL_RANKS = None
|
|
|
- from aphrodite.distributed.device_communicators import pynccl_utils
|
|
|
- # Destroy the pynccl states if any.
|
|
|
- pynccl_utils.destroy_process_group()
|
|
|
-
|
|
|
-
|
|
|
-# Whether to use pynccl for nccl all reduce.
|
|
|
-# We use pynccl for all reduce when using CUDA graph, because torch.distributed
|
|
|
-# is not well supported by CUDA graph.
|
|
|
-_ENABLE_PYNCCL_FOR_ALL_REDUCE = False
|
|
|
-
|
|
|
-
|
|
|
-@contextlib.contextmanager
|
|
|
-def with_pynccl_for_all_reduce():
|
|
|
- """use Pynccl instead of torch.distributed for all reduce"""
|
|
|
- from aphrodite.distributed.device_communicators import pynccl_utils
|
|
|
- tp_size = get_tensor_model_parallel_world_size()
|
|
|
- if tp_size == 1:
|
|
|
- # No-op.
|
|
|
- # NOTE: We don't initialize Pynccl when tp_size is 1.
|
|
|
- yield
|
|
|
- else:
|
|
|
- global _ENABLE_PYNCCL_FOR_ALL_REDUCE
|
|
|
- old = _ENABLE_PYNCCL_FOR_ALL_REDUCE
|
|
|
- _ENABLE_PYNCCL_FOR_ALL_REDUCE = True
|
|
|
+def get_current_tp_rank_partition_offset(total_size: int,
|
|
|
+ tp_rank: Optional[int] = None,
|
|
|
+ tp_size: Optional[int] = None,
|
|
|
+ multiple_of: int = 1) -> int:
|
|
|
+ if tp_rank is None:
|
|
|
+ tp_rank = get_tensor_model_parallel_rank()
|
|
|
+
|
|
|
+ if tp_size is None:
|
|
|
+ tp_size = get_tensor_model_parallel_world_size()
|
|
|
+
|
|
|
+ assert total_size % multiple_of == 0
|
|
|
+ total_size = total_size // multiple_of
|
|
|
+ return ((total_size // tp_size) * tp_rank +
|
|
|
+ min(total_size % tp_size, tp_rank)) * multiple_of
|
|
|
+
|
|
|
|
|
|
- stream = torch.cuda.current_stream()
|
|
|
- with pynccl_utils.set_pynccl_stream(stream):
|
|
|
- yield
|
|
|
- _ENABLE_PYNCCL_FOR_ALL_REDUCE = old
|
|
|
+def get_current_tp_rank_partition_size(total_size: int,
|
|
|
+ tp_rank: Optional[int] = None,
|
|
|
+ tp_size: Optional[int] = None,
|
|
|
+ multiple_of: int = 1) -> int:
|
|
|
+ if tp_rank is None:
|
|
|
+ tp_rank = get_tensor_model_parallel_rank()
|
|
|
|
|
|
+ if tp_size is None:
|
|
|
+ tp_size = get_tensor_model_parallel_world_size()
|
|
|
|
|
|
-def is_pynccl_enabled_for_all_reduce():
|
|
|
- """check if Pynccl is enabled for all reduce"""
|
|
|
- global _ENABLE_PYNCCL_FOR_ALL_REDUCE
|
|
|
- return _ENABLE_PYNCCL_FOR_ALL_REDUCE
|
|
|
+ assert total_size % multiple_of == 0
|
|
|
+ total_size = total_size // multiple_of
|
|
|
+ return ((total_size // tp_size) +
|
|
|
+ (total_size % tp_size > tp_rank)) * multiple_of
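+
+
+# Worked example (the values are assumed): with total_size=10, tp_size=4 and
+# multiple_of=1, get_current_tp_rank_partition_size returns 3, 3, 2, 2 for
+# tp_rank 0..3, and get_current_tp_rank_partition_offset returns 0, 3, 6, 8,
+# so lower ranks absorb the remainder and the partitions remain contiguous.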
|