pynccl_utils.py 1.8 KB

import contextlib
from typing import Optional

import torch
from loguru import logger
from torch.distributed import ReduceOp

try:
    from aphrodite.distributed.device_communicators.pynccl import (
        NCCLCommunicator,
        ncclGetVersion,
    )
except Exception as e:
    # In non-NVIDIA environments (e.g. machines with AMD GPUs), the NCCL
    # module cannot be imported.
    logger.info(f"Failed to import NCCL library: {e}")
    logger.info("It is expected if you are not running on NVIDIA GPUs.")

comm: Optional["NCCLCommunicator"] = None


def is_initialized() -> bool:
    """Returns whether the NCCL backend is initialized."""
    return comm is not None


@contextlib.contextmanager
def set_pynccl_stream(stream: torch.cuda.Stream):
    """Set the CUDA stream used for communication."""
    try:
        comm.stream = stream
        yield
    finally:
        pass


def init_process_group(world_size: int,
                       rank: int,
                       init_method: str,
                       local_rank: int = -1) -> None:
    """Initializes the NCCL communicator for the process group."""
    assert not is_initialized()
    global comm
    logger.info(f"Aphrodite is using nccl=={ncclGetVersion()}")
    comm = NCCLCommunicator(init_method=init_method,
                            world_size=world_size,
                            local_rank=local_rank,
                            rank=rank)


def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
    """All-reduces the input tensor across the process group."""
    assert input_.is_cuda, f"{input_} should be a cuda tensor"
    comm.all_reduce(input_, op)


def destroy_process_group() -> None:
    """Destroys the NCCL communicator."""
    global comm
    comm = None


def get_world_size() -> int:
    """Returns the world size."""
    return comm.world_size


def get_nccl_backend():
    """Returns the NCCL communicator object, or None if uninitialized."""
    return comm
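
A minimal usage sketch of the helpers above, assuming the module is importable as aphrodite.distributed.device_communicators.pynccl_utils and that a CUDA device and a reachable TCP init_method are available; the address, world size, and tensor shape are illustrative placeholders, not values taken from this module:

import torch

# Assumed import path for this module.
from aphrodite.distributed.device_communicators import pynccl_utils

# Single-process example; real deployments pass each worker's own
# rank/world_size from the launcher.
pynccl_utils.init_process_group(world_size=1, rank=0,
                                init_method="tcp://127.0.0.1:29500")

tensor = torch.ones(8, device="cuda")
with pynccl_utils.set_pynccl_stream(torch.cuda.current_stream()):
    pynccl_utils.all_reduce(tensor)  # op defaults to ReduceOp.SUM

pynccl_utils.destroy_process_group()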