123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150 |
- from contextlib import contextmanager
- from typing import Optional, Union
- # ===================== import region =====================
- import torch
- import torch.distributed as dist
- from loguru import logger
- from torch.distributed import ProcessGroup, ReduceOp
- from aphrodite.distributed.device_communicators.pynccl_wrapper import (
- NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
- ncclRedOpTypeEnum, ncclUniqueId)
- from aphrodite.distributed.parallel_state import (get_cpu_world_group,
- get_local_rank)
class PyNcclCommunicator:
    """Wrapper that drives a NCCL communicator through `NCCLLibrary`.

    After construction the instance exposes:

    * ``available`` -- True when a NCCL communicator was actually created
      (i.e. the group has more than one rank and the NCCL library loaded).
    * ``disabled`` -- when True, `all_reduce` is a no-op. It is True after
      construction; flip it temporarily with `change_state`.
    * ``stream`` -- the CUDA stream used by default for collectives, or
      None when the communicator is not available.
    """

    def __init__(
        self,
        group: Optional[ProcessGroup] = None,
        device: Optional[Union[int, str, torch.device]] = None,
        library_path: Optional[str] = None,
    ):
        """
        Args:
            group: the process group to work on. If None, the group
                returned by `get_cpu_world_group()` is used. It must not
                be a NCCL-backed group.
            device: the device to bind the PyNcclCommunicator to. If None,
                it will be bound to f"cuda:{local_rank}". An int is
                interpreted as a CUDA device index, a str is passed to
                `torch.device`.
            library_path: the path to the NCCL library. If None, it will
                use the default library path.

        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.
        """
        assert dist.is_initialized()
        group = get_cpu_world_group() if group is None else group
        assert dist.get_backend(group) != dist.Backend.NCCL, (
            "PyNcclCommunicator should be attached to a non-NCCL group.")
        self.group = group
        # note: this rank is the rank in the group
        self.rank = dist.get_rank(group)
        self.world_size = dist.get_world_size(group)

        # if world_size == 1, no need to create communicator
        if self.world_size == 1:
            self.available = False
            self.disabled = True
            self.stream = None
            return

        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # disable because of missing NCCL library
            # e.g. in a non-GPU environment
            self.available = False
            self.disabled = True
            self.stream = None
            return

        self.available = True
        # temporarily enabled so that the warmup all_reduce below runs
        self.disabled = False

        logger.info(f"Aphrodite is using nccl=={self.nccl.ncclGetVersion()}")

        if self.rank == 0:
            # get the unique id from NCCL
            self.unique_id = self.nccl.ncclGetUniqueId()
        else:
            # construct an empty unique id
            self.unique_id = ncclUniqueId()

        # share rank 0's unique id with every rank of the group by
        # broadcasting its raw bytes over the (non-NCCL) group
        tensor = torch.ByteTensor(list(self.unique_id.internal))
        ranks = dist.get_process_group_ranks(group)
        # arg `src` in `broadcast` is the global rank
        dist.broadcast(tensor, src=ranks[0], group=group)
        byte_list = tensor.tolist()
        for i, byte in enumerate(byte_list):
            self.unique_id.internal[i] = byte

        if device is None:
            local_rank = get_local_rank()
            device = torch.device(f"cuda:{local_rank}")
        elif isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device

        # nccl communicator and stream will use this device
        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one
        with torch.cuda.device(device):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank)
            self.stream = torch.cuda.Stream()

            # A small all_reduce for warmup.
            data = torch.zeros(1, device=device)
            self.all_reduce(data)
            self.stream.synchronize()
            del data

        # by default it is disabled, e.g. in profiling models and prefill
        # phase. to use it, use under `with obj.change_state(enable=True)`,
        # usually when we are using CUDA graph.
        self.disabled = True

    def all_reduce(self,
                   tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None):
        """All-reduce `tensor` in place across the ranks of the group.

        Does nothing (and returns None) while the communicator is
        disabled.

        Args:
            tensor: the tensor to reduce; it is used as both the send and
                the receive buffer, and must live on `self.device`.
            op: the reduction operation, SUM by default.
            stream: the CUDA stream handed to NCCL. If None, `self.stream`
                is used.
        """
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
        self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
                                buffer_type(tensor.data_ptr()), tensor.numel(),
                                ncclDataTypeEnum.from_torch(tensor.dtype),
                                ncclRedOpTypeEnum.from_torch(op), self.comm,
                                cudaStream_t(stream.cuda_stream))

    @contextmanager
    def change_state(self,
                     enable: Optional[bool] = None,
                     stream: Optional[torch.cuda.Stream] = None):
        """
        A context manager to change the state of the communicator.

        Inside the `with` block, `self.disabled` is set to `not enable`
        and `self.stream` to `stream`; the previous values are restored
        on exit, including when the block raises.

        Args:
            enable: whether the communicator should be enabled inside the
                block. If None, defaults to `self.available`.
            stream: the stream to use inside the block. If None, the
                current `self.stream` is kept.
        """
        if enable is None:
            # guess a default value when not specified
            enable = self.available

        if stream is None:
            stream = self.stream

        old_disable = self.disabled
        old_stream = self.stream

        self.stream = stream
        self.disabled = not enable
        try:
            yield
        finally:
            # always restore, otherwise an exception raised inside the
            # `with` body would leave the communicator stuck in the
            # temporary state
            self.disabled = old_disable
            self.stream = old_stream