123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307 |
- import os
- from contextlib import contextmanager
- from typing import Any, List, Optional, Union
- import torch
- import torch.distributed as dist
- from loguru import logger
- from torch.distributed import ProcessGroup
- from aphrodite.distributed.parallel_state import (
- get_local_rank, get_tensor_model_parallel_cpu_group)
try:
    import pynvml

    from aphrodite._C import custom_ar

    @contextmanager
    def _nvml():
        """Keep NVML initialised for the duration of the context.

        ``nvmlShutdown`` is only called when ``nvmlInit`` succeeded.
        """
        # Initialise *before* entering the ``try``: if ``nvmlInit`` raises,
        # the error propagates unchanged instead of being followed by a
        # shutdown call on a library that never came up.
        pynvml.nvmlInit()
        try:
            yield
        finally:
            pynvml.nvmlShutdown()

except ImportError:
    # Either `pynvml` or the compiled `custom_ar` extension is unavailable
    # (e.g. on AMD GPUs). Expose `None` placeholders so that callers can test
    # for availability.
    custom_ar = None
    pynvml = None

    @contextmanager
    def _nvml():
        """No-op stand-in used when the imports above failed."""
        yield
@_nvml()
def _is_full_nvlink(device_ids: List[int]) -> bool:
    """
    Query if the set of gpus are fully connected by nvlink (1 hop).

    Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
    so it works on real physical device ids.

    Args:
        device_ids: physical device indices, as accepted by
            `pynvml.nvmlDeviceGetHandleByIndex`.

    Returns:
        True when every pair of devices reports `NVML_P2P_STATUS_OK` for the
        NVLink capability; False as soon as one pair does not, or when the
        NVML query itself raises `NVMLError`.
    """
    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
    for i, handle in enumerate(handles):
        # Each unordered pair only needs to be checked once.
        for peer_handle in handles[i + 1:]:
            try:
                p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                    handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
            except pynvml.NVMLError as error:
                # loguru has no `exc_info=` keyword; the exception has to be
                # attached through `opt(exception=...)` to be rendered.
                logger.opt(exception=error).error(
                    "NVLink detection failed. This is normal if your"
                    " machine has no NVLink equipped.")
                return False
            if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                return False
    return True
def _can_p2p(rank: int, world_size: int) -> bool:
    """Return True when `gpu_p2p_access_check` passes for every other rank.

    `rank` is checked against each index in `range(world_size)` except
    itself; evaluation stops at the first failing peer.
    """
    from aphrodite.distributed.utils import gpu_p2p_access_check

    return all(
        gpu_p2p_access_check(rank, peer) for peer in range(world_size)
        if peer != rank)
class CustomAllreduce:
    """Python-side owner of the buffers used by the `custom_ar` extension.

    An instance starts out disabled (`self.disabled = True`) and is only
    enabled once every check in `__init__` has passed. While disabled,
    `custom_all_reduce` returns None.
    """

    _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]

    # max_size: max supported allreduce size
    def __init__(self,
                 group: Optional[ProcessGroup] = None,
                 device: Optional[Union[int, str, torch.device]] = None,
                 max_size: int = 8192 * 1024) -> None:
        """
        Args:
            group: the process group to work on. If None, the group returned
                by `get_tensor_model_parallel_cpu_group()` is used. It must
                not use the NCCL backend.
            device: the device to bind the CustomAllreduce to. If None,
                it will be bound to f"cuda:{local_rank}".
            max_size: size in bytes of the pre-registered staging buffer and
                of the scratch area appended to the meta buffer. It is also
                the size limit forwarded to `custom_ar.should_custom_ar`.
        It is the caller's responsibility to make sure each communicator
        is bound to a unique device, and all communicators in this group
        are in the same node.
        """
        self._IS_CAPTURING = False
        self.disabled = True

        if custom_ar is None:
            # disable because of missing custom allreduce library
            # e.g. in a non-cuda environment
            return

        group = group or get_tensor_model_parallel_cpu_group()
        self.group = group

        assert dist.get_backend(group) != dist.Backend.NCCL, (
            "CustomAllreduce should be attached to a non-NCCL group.")

        rank = dist.get_rank(group=self.group)
        world_size = dist.get_world_size(group=self.group)
        if world_size == 1:
            # No need to initialize custom allreduce for single GPU case.
            return

        if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
            # loguru formats with `str.format`, so placeholders must be `{}`.
            logger.warning(
                "Custom allreduce is disabled due to an unsupported world"
                " size: {}. Supported world sizes: {}. To silence this "
                "warning, specify disable_custom_all_reduce=True explicitly.",
                world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
            return

        if device is None:
            local_rank = get_local_rank()
            device = torch.device(f"cuda:{local_rank}")
        elif isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device

        # Translate the logical device index into the physical one, because
        # the NVLink query below works on physical device ids.
        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if cuda_visible_devices:
            device_ids = list(map(int, cuda_visible_devices.split(",")))
        else:
            device_ids = list(range(torch.cuda.device_count()))

        physical_device_id = device_ids[device.index]
        tensor = torch.tensor([physical_device_id],
                              dtype=torch.int,
                              device="cpu")
        gather_list = [
            torch.tensor([0], dtype=torch.int, device="cpu")
            for _ in range(world_size)
        ]
        dist.all_gather(gather_list, tensor, group=self.group)
        physical_device_ids = [t.item() for t in gather_list]

        # test nvlink first, this will filter out most of the cases
        # where custom allreduce is not supported
        # this checks hardware and driver support for NVLink
        full_nvlink = _is_full_nvlink(physical_device_ids)
        if world_size > 2 and not full_nvlink:
            logger.warning(
                "Custom allreduce is disabled because it's not supported on"
                " more than two PCIe-only GPUs. To silence this warning, "
                "specify disable_custom_all_reduce=True explicitly.")
            return
        # test P2P capability, this checks software/cudaruntime support
        # this is expensive to compute at the first time
        # then we cache the result
        if not _can_p2p(rank, world_size):
            logger.warning(
                "Custom allreduce is disabled because your platform lacks "
                "GPU P2P capability or P2P test failed. To silence this "
                "warning, specify disable_custom_all_reduce=True explicitly.")
            return

        self.disabled = False
        # Buffer memory is owned by this Python class and passed to C++.
        # The meta buffer consists of two parts: meta data for
        # synchronization (256 bytes) and a temporary buffer for storing
        # intermediate allreduce results.
        self.meta = torch.zeros(custom_ar.meta_size() + max_size,
                                dtype=torch.uint8,
                                device=self.device)
        # This is a pre-registered IPC buffer. In eager mode, input tensors
        # are first copied into this buffer before allreduce is performed
        self.buffer = torch.empty(max_size,
                                  dtype=torch.uint8,
                                  device=self.device)
        # This is a buffer for storing the tuples of pointers pointing to
        # IPC buffers from all ranks. Each registered tuple has size of
        # 8*world_size bytes where world_size is at most 8. Allocating 8MB
        # is enough for 131072 such tuples. The largest model I've seen only
        # needs less than 10000 of registered tuples.
        self.rank_data = torch.empty(8 * 1024 * 1024,
                                     dtype=torch.uint8,
                                     device=self.device)
        self.max_size = max_size
        self.rank = rank
        self.world_size = world_size
        handles, offsets = self._get_ipc_meta(self.meta)
        self.full_nvlink = full_nvlink
        self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data,
                                             handles, offsets, rank,
                                             self.full_nvlink)
        self.register_buffer(self.buffer)

    @contextmanager
    def capture(self):
        """
        The main responsibility of this context manager is the
        `register_graph_buffers` call at the end of the context.
        It records all the buffer addresses used in the CUDA graph.

        `_IS_CAPTURING` is True inside the context and reset to False on
        exit, even if the body raises.
        """
        try:
            self._IS_CAPTURING = True
            yield
        finally:
            self._IS_CAPTURING = False
            if not self.disabled:
                self.register_graph_buffers()

    def _get_ipc_meta(self, inp: torch.Tensor):
        """Share `inp`'s storage over CUDA IPC and gather the result.

        Returns the `(handles, offsets)` lists produced by
        `_gather_ipc_meta`, one entry per rank of the group.
        """
        data = inp.untyped_storage()._share_cuda_()
        shard_data = (
            data[1],  # ipc handle to base ptr
            data[3],  # offset of base ptr
        )
        return self._gather_ipc_meta(shard_data)

    def _gather_ipc_meta(self, shard_data):
        """Exchange `shard_data` (a `(handle, offset)` pair) across the group.

        Returns:
            `(handles, offsets)`: two lists holding every rank's handle and
            offset, ordered by the sorted ranks of the process group.
        """
        # Note: don't use `[[None]] * self.world_size` here
        # because it will create a list of the same reference
        all_data: List[List[Any]] = [[None] for _ in range(self.world_size)]
        all_data[self.rank][0] = shard_data

        ranks = dist.get_process_group_ranks(group=self.group)
        ranks.sort()
        for i, rank in enumerate(ranks):
            dist.broadcast_object_list(all_data[i],
                                       src=rank,
                                       group=self.group,
                                       device="cpu")

        # we cannot directly use `dist.all_gather_object` here
        # because it is incompatible with `gloo` backend under inference mode.
        # see https://github.com/pytorch/pytorch/issues/126032 for details.

        handles = [entry[0][0] for entry in all_data]
        offsets = [entry[0][1] for entry in all_data]
        return handles, offsets

    def register_buffer(self, inp: torch.Tensor):
        """Register `inp` with the extension using every rank's IPC meta."""
        handles, offsets = self._get_ipc_meta(inp)
        custom_ar.register_buffer(self._ptr, inp, handles, offsets)

    def register_graph_buffers(self):
        """Gather and register the graph buffer IPC meta reported by
        `custom_ar.get_graph_buffer_ipc_meta`."""
        handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr)
        handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
        # loguru formats with `str.format`, so the placeholder must be `{}`.
        logger.info("Registering {} cuda graph addresses", len(offset))
        custom_ar.register_graph_buffers(self._ptr, handles, offsets)

    def should_custom_ar(self, inp: torch.Tensor):
        """Forward to `custom_ar.should_custom_ar` with this instance's
        `max_size`, `world_size` and `full_nvlink` settings."""
        return custom_ar.should_custom_ar(inp, self.max_size, self.world_size,
                                          self.full_nvlink)

    # all reduce, assuming inp tensor is IPC registered with register_buffer,
    # or, in the context of cuda graphs, register_graph_buffers
    def all_reduce_reg(self,
                       inp: torch.Tensor,
                       out: Optional[torch.Tensor] = None):
        """All-reduce `inp` into `out` and return `out`.

        When `out` is None a new tensor is allocated with
        `torch.empty_like(inp)`.
        """
        if out is None:
            out = torch.empty_like(inp)
        custom_ar.all_reduce_reg(self._ptr, inp, out)
        return out

    # all reduce, assuming inp tensor is NOT IPC registered
    def all_reduce_unreg(self,
                         inp: torch.Tensor,
                         out: Optional[torch.Tensor] = None):
        """All-reduce `inp` into `out` via `self.buffer` and return `out`.

        When `out` is None a new tensor is allocated with
        `torch.empty_like(inp)`.
        """
        if out is None:
            out = torch.empty_like(inp)
        custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out)
        return out

    def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
        """Run the custom all-reduce on `input` when it is applicable.

        Returns:
            The output tensor, or None when this instance is disabled or
            `should_custom_ar` rejects `input`.
        """
        # when custom allreduce is disabled, this will be None
        if self.disabled:
            return None
        if self._IS_CAPTURING:
            if torch.cuda.is_current_stream_capturing():
                if self.should_custom_ar(input):
                    return self.all_reduce_reg(input)
            else:
                if self.should_custom_ar(input):
                    # if warm up, mimic the allocation pattern
                    # since custom allreduce is out-of-place
                    return torch.empty_like(input)
        else:
            # note: outside of cuda graph context,
            # custom allreduce incurs a cost of cudaMemcpy, which should
            # be small(<=1% of overall latency) compared to the performance
            # gains of using custom kernels
            if self.should_custom_ar(input):
                return self.all_reduce_unreg(input)

        return None

    def close(self):
        """Dispose of the extension handle, if there is one. Safe to call
        more than once."""
        # `getattr` keeps this safe when `__init__` never ran to completion:
        # `disabled` is flipped to False before `_ptr` is assigned, and
        # `__del__` still runs on a partially initialised instance.
        if not getattr(self, "disabled", True) and getattr(self, "_ptr", 0):
            custom_ar.dispose(self._ptr)
            self._ptr = 0

    def __del__(self):
        self.close()
|