"""CuPy utilities for all-reduce.

We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing
CUDA graphs, because torch.distributed.all_reduce causes errors when capturing
CUDA graphs.

NOTE: We use CuPy 12.3 since CuPy 13.0 does not support Python 3.8.
TODO: Remove this file when torch.distributed.all_reduce is fixed.
"""
import contextlib

import torch
from torch.distributed import ReduceOp

try:
    import cupy
    from cupy.cuda import nccl
    from cupyx.distributed import NCCLBackend
except ImportError as e:
    # Defer the import failure: keep the ImportError in `cupy` so
    # init_process_group() can re-raise it with context, and provide a
    # stub NCCLBackend so the subclass definition below still succeeds.
    cupy = e
    nccl = None

    class NCCLBackend:
        ...
# Maps torch.distributed reduce ops to the operation names that the CuPy
# NCCL backend's all_reduce expects.
_OP_MAPPING = {
    ReduceOp.SUM: "sum",
    ReduceOp.PRODUCT: "prod",
    ReduceOp.MIN: "min",
    ReduceOp.MAX: "max",
}
class NCCLBackendWithBFloat16(NCCLBackend):
    """CuPy NCCL backend patched to handle torch.bfloat16 tensors.

    Overriding _get_nccl_dtype_and_count is enough to add bfloat16 support
    for most operations, but broadcast will still fail (fixing that would
    require changes in compiled cupy code).
    """

    def _get_nccl_dtype_and_count(self, array, count=None):
        dtype, count = super()._get_nccl_dtype_and_count(array, count)
        # all_reduce() views bfloat16 tensors as float16 and records the
        # real dtype on the cupy array; swap the NCCL dtype back here.
        if getattr(array, "_torch_dtype", None) is torch.bfloat16:
            dtype = nccl.NCCL_BFLOAT16
        return dtype, count

    def barrier(self) -> None:
        raise RuntimeError(
            "Currently, CuPy NCCL barrier is not supported since the TCP "
            "store is immediately stopped after the initialization.")
- _NCCL_BACKEND = None
- _WORLD_SIZE = 0
- def is_initialized() -> bool:
- """Returns whether the NCCL backend is initialized."""
- return _NCCL_BACKEND is not None
@contextlib.contextmanager
def set_cupy_stream(stream: torch.cuda.Stream):
    """Make the given torch CUDA stream current for CuPy communication."""
    # Wrap the raw torch stream handle in a CuPy ExternalStream and keep it
    # active for the duration of the with-block.
    external_stream = cupy.cuda.ExternalStream(stream.cuda_stream,
                                               stream.device_index)
    with external_stream:
        yield
def init_process_group(world_size: int, rank: int, host: str,
                       port: int) -> None:
    """Initializes the CuPy NCCL backend.

    Args:
        world_size: Total number of processes in the group.
        rank: Rank of this process, in [0, world_size).
        host: Host of the TCP store used for the NCCL handshake.
        port: Port of the TCP store.

    Raises:
        ImportError: If cupy failed to import at module load time.

    TODO: handle NCCL timeouts.
    """
    assert not is_initialized()
    if isinstance(cupy, Exception):
        # Module-level import stashed the ImportError in `cupy`; surface it.
        raise ImportError(
            "NCCLBackend is not available. Please install cupy.") from cupy

    # TODO: Create TP and PP process groups for CuPy.
    global _NCCL_BACKEND, _WORLD_SIZE

    assert world_size > 0, f"{world_size=} should be a positive integer"
    # Fix: error message previously read "a integer".
    assert 0 <= rank < world_size, (
        f"{rank=} should be an integer between [0, {world_size})")

    # Bind CuPy to the same CUDA device torch is using before creating the
    # communicator.
    cupy.cuda.runtime.setDevice(torch.cuda.current_device())
    _NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port)
    _WORLD_SIZE = world_size

    # Stop the TCP store to prevent the deadlock issues at termination time.
    # FIXME: This is hacky. Find a more robust solution.
    if rank == 0 and hasattr(_NCCL_BACKEND, "_store"):
        _NCCL_BACKEND._store.stop()
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
    """All-reduces the input tensor across the process group."""
    assert input_.is_cuda, f"{input_} should be a cuda tensor"
    original_dtype = input_.dtype
    # Hack to support bfloat16: cupy rejects bfloat16 arrays, so reinterpret
    # the buffer as float16 (the underlying bits are untouched) and stash the
    # real dtype on the cupy array for NCCLBackendWithBFloat16 to pick up.
    tensor = (input_.view(torch.float16)
              if original_dtype is torch.bfloat16 else input_)
    arr = cupy.asarray(tensor)
    arr._torch_dtype = original_dtype  # pylint: disable=protected-access
    _NCCL_BACKEND.all_reduce(in_array=arr, out_array=arr, op=_OP_MAPPING[op])
def destroy_process_group() -> None:
    """Destroys the NCCL backend by resetting the module-level state."""
    global _NCCL_BACKEND, _WORLD_SIZE
    _NCCL_BACKEND = None
    _WORLD_SIZE = 0


def get_world_size() -> int:
    """Returns the world size."""
    return _WORLD_SIZE


def get_nccl_backend():
    """Returns the CuPy NCCL backend (None when uninitialized)."""
    return _NCCL_BACKEND