cupy_utils.py

  1. """CuPy utilities for all-reduce.
  2. We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing
  3. CUDA graphs, because torch.distributed.all_reduce causes errors when capturing
  4. CUDA graphs.
  5. NOTE: We use CuPy 12.3 since CuPy 13.0 does not support Python 3.8.
  6. TODO: Remove this file when torch.distributed.all_reduce is fixed.
  7. """
import contextlib

import torch
from torch.distributed import ReduceOp

try:
    import cupy
    from cupy.cuda import nccl
    from cupyx.distributed import NCCLBackend
except ImportError as e:
    cupy = e
    nccl = None

    class NCCLBackend:
        ...

_OP_MAPPING = {
    ReduceOp.SUM: "sum",
    ReduceOp.PRODUCT: "prod",
    ReduceOp.MIN: "min",
    ReduceOp.MAX: "max",
}


class NCCLBackendWithBFloat16(NCCLBackend):
    # This is enough to add bfloat16 support for most operations,
    # but broadcast will fail (will require changes in compiled
    # cupy code).

    def _get_nccl_dtype_and_count(self, array, count=None):
        nccl_dtype, count = super()._get_nccl_dtype_and_count(array, count)
        torch_dtype = getattr(array, "_torch_dtype", None)
        if torch_dtype is torch.bfloat16:
            nccl_dtype = nccl.NCCL_BFLOAT16
        return nccl_dtype, count

    def barrier(self) -> None:
        raise RuntimeError(
            "Currently, CuPy NCCL barrier is not supported since the TCP "
            "store is immediately stopped after the initialization.")


_NCCL_BACKEND = None
_WORLD_SIZE = 0


def is_initialized() -> bool:
    """Returns whether the NCCL backend is initialized."""
    return _NCCL_BACKEND is not None


@contextlib.contextmanager
def set_cupy_stream(stream: torch.cuda.Stream):
    """Sets the CUDA stream for communication."""
    cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream,
                                           stream.device_index)
    with cupy_stream:
        yield


def init_process_group(world_size: int, rank: int, host: str,
                       port: int) -> None:
    """Initializes the CuPy NCCL backend.

    # TODO: handle NCCL timeouts.
    """
    assert not is_initialized()

    if isinstance(cupy, Exception):
        raise ImportError(
            "NCCLBackend is not available. Please install cupy.") from cupy

    # TODO: Create TP and PP process groups for CuPy.
    global _NCCL_BACKEND
    global _WORLD_SIZE
    assert world_size > 0, f"{world_size=} should be a positive integer"
    assert 0 <= rank < world_size, (
        f"{rank=} should be an integer in [0, {world_size})")

    cupy.cuda.runtime.setDevice(torch.cuda.current_device())
    _NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port)
    _WORLD_SIZE = world_size

    # Stop the TCP store to prevent deadlock issues at termination time.
    # FIXME: This is hacky. Find a more robust solution.
    if rank == 0 and hasattr(_NCCL_BACKEND, "_store"):
        _NCCL_BACKEND._store.stop()


def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
    """All-reduces the input tensor across the process group."""
    assert input_.is_cuda, f"{input_} should be a cuda tensor"
    # Hack to support bfloat16.
    torch_dtype = input_.dtype
    if torch_dtype is torch.bfloat16:
        # We need to view as float16, otherwise cupy will fail.
        # This will not change the underlying data.
        input_ = input_.view(torch.float16)
    cupy_input = cupy.asarray(input_)
    cupy_input._torch_dtype = torch_dtype  # pylint: disable=protected-access
    _NCCL_BACKEND.all_reduce(in_array=cupy_input,
                             out_array=cupy_input,
                             op=_OP_MAPPING[op])


def destroy_process_group() -> None:
    """Destroys the NCCL backend."""
    global _NCCL_BACKEND
    global _WORLD_SIZE
    _NCCL_BACKEND = None
    _WORLD_SIZE = 0


def get_world_size() -> int:
    """Returns the world size."""
    return _WORLD_SIZE


def get_nccl_backend():
    """Returns the NCCL backend, or None if it is not initialized."""
    return _NCCL_BACKEND
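

# Minimal usage sketch: one way the functions above could be wired together
# on each rank. The RANK/WORLD_SIZE environment variables, host, and port
# below are illustrative assumptions; a real launcher (one process per GPU)
# would supply its own values.
if __name__ == "__main__":
    import os

    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Rank 0 hosts the TCP store used for the NCCL handshake.
    init_process_group(world_size=world_size, rank=rank,
                       host="127.0.0.1", port=29500)

    # Run the communication on a dedicated stream, mirroring how CUDA graph
    # capture pins all work to a single stream.
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream), set_cupy_stream(stream):
        x = torch.ones(8, device="cuda")
        all_reduce(x)  # in-place sum across all ranks
    stream.synchronize()

    assert torch.allclose(x, torch.full_like(x, float(world_size)))
    destroy_process_group()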