pynccl_utils.py

import contextlib
from typing import Optional

import torch
from loguru import logger
from torch.distributed import ProcessGroup, ReduceOp

try:
    from aphrodite.distributed.device_communicators.pynccl import (
        NCCLCommunicator, ncclGetVersion)
except Exception as e:
    # In non-NVIDIA environments we can't import the NCCL module,
    # e.g. when running on machines with AMD GPUs.
    logger.info(f"Failed to import NCCL library: {e}")
    logger.info("It is expected if you are not running on NVIDIA GPUs.")

comm: Optional["NCCLCommunicator"] = None


def is_initialized() -> bool:
    """Returns whether the NCCL backend is initialized."""
    return comm is not None


@contextlib.contextmanager
def set_pynccl_stream(stream: torch.cuda.Stream):
    """Set the CUDA stream used for NCCL communication.

    Note: the previous stream is not restored when the context exits.
    """
    comm.stream = stream
    yield


def init_process_group(group: Optional[ProcessGroup] = None) -> None:
    """Initialize the NCCL communicator for the given process group."""
    assert not is_initialized()
    global comm
    logger.info(f"Aphrodite is using nccl=={ncclGetVersion()}")
    comm = NCCLCommunicator(group=group)


def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
    """All-reduces the input tensor across the process group, in place."""
    assert input_.is_cuda, f"{input_} should be a cuda tensor"
    comm.all_reduce(input_, op)


def destroy_process_group() -> None:
    """Drop the communicator reference so the backend can be re-initialized."""
    global comm
    comm = None


def get_world_size() -> int:
    """Returns the world size."""
    return comm.world_size


def get_nccl_backend() -> Optional["NCCLCommunicator"]:
    """Return the underlying NCCLCommunicator, or None if uninitialized."""
    return comm
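

# ---------------------------------------------------------------------------
# Illustrative usage sketch (comments only, not part of the module). It
# assumes torch.distributed has already been initialized elsewhere, that the
# process runs on an NVIDIA GPU, and that this module lives alongside pynccl
# in aphrodite.distributed.device_communicators; the calls below are the
# helpers defined in this file.
#
#     import torch
#     from aphrodite.distributed.device_communicators import pynccl_utils
#
#     if not pynccl_utils.is_initialized():
#         pynccl_utils.init_process_group()  # wraps the default group
#
#     t = torch.ones(16, device="cuda")
#     with pynccl_utils.set_pynccl_stream(torch.cuda.current_stream()):
#         pynccl_utils.all_reduce(t)  # in-place SUM across all ranks
#
#     pynccl_utils.destroy_process_group()
# ---------------------------------------------------------------------------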