from contextlib import contextmanager
from typing import Optional, Union

# ===================== import region =====================
import torch
import torch.distributed as dist
from loguru import logger
from torch.distributed import ProcessGroup, ReduceOp

from aphrodite.distributed.device_communicators.pynccl_wrapper import (
    NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
    ncclRedOpTypeEnum, ncclUniqueId)


class PyNcclCommunicator:

    def __init__(
        self,
        group: ProcessGroup,
        device: Union[int, str, torch.device],
        library_path: Optional[str] = None,
    ):
        """
        Args:
            group: the process group to work on. If None, it will use the
                default process group.
            device: the device to bind the PyNcclCommunicator to. If None,
                it will be bound to f"cuda:{local_rank}".
            library_path: the path to the NCCL library. If None, it will
                use the default library path.

        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.
        """
        assert dist.is_initialized()
        assert dist.get_backend(group) != dist.Backend.NCCL, (
            "PyNcclCommunicator should be attached to a non-NCCL group.")
        self.group = group
        # note: this rank is the rank in the group
        self.rank = dist.get_rank(group)
        self.world_size = dist.get_world_size(group)

        # if world_size == 1, there is no need to create a communicator
        if self.world_size == 1:
            self.available = False
            self.disabled = True
            self.stream = None
            return
        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # disable because of a missing NCCL library,
            # e.g. in a non-GPU environment
            self.available = False
            self.disabled = True
            self.stream = None
            return

        self.available = True
        self.disabled = False

        logger.debug(f"Aphrodite is using nccl=={self.nccl.ncclGetVersion()}")

        if self.rank == 0:
            # get the unique id from NCCL
            self.unique_id = self.nccl.ncclGetUniqueId()
        else:
            # construct an empty unique id
            self.unique_id = ncclUniqueId()
        tensor = torch.ByteTensor(list(self.unique_id.internal))
        ranks = dist.get_process_group_ranks(group)
        # arg `src` in `broadcast` is the global rank
        dist.broadcast(tensor, src=ranks[0], group=group)
        byte_list = tensor.tolist()
        for i, byte in enumerate(byte_list):
            self.unique_id.internal[i] = byte

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        # the nccl communicator and stream will use this device;
        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one
        with torch.cuda.device(device):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank)
            self.stream = torch.cuda.Stream()

            # a small all_reduce for warmup
            data = torch.zeros(1, device=device)
            self.all_reduce(data)
            self.stream.synchronize()
            del data

        # By default the communicator is disabled, e.g. when profiling models
        # or during the prefill phase. To use it, run under
        # `with obj.change_state(enable=True)`, usually when we are using
        # CUDA graphs.
        self.disabled = True

    def all_reduce(self,
                   tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None):
        if self.disabled:
            return
        # an nccl communicator created on a specific device
        # will only work on tensors on the same device;
        # otherwise it will cause an "illegal memory access"
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
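        # in-place all-reduce: the same device pointer is passed as both the
        # send and receive buffer, so `tensor` is overwritten with the result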
        self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
                                buffer_type(tensor.data_ptr()), tensor.numel(),
                                ncclDataTypeEnum.from_torch(tensor.dtype),
                                ncclRedOpTypeEnum.from_torch(op), self.comm,
                                cudaStream_t(stream.cuda_stream))

    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
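        # `dst` is the rank within this communicator's group (the same
        # group-local numbering as `self.rank`), not the global rank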
        self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                           self.comm, cudaStream_t(stream.cuda_stream))

    def recv(self, tensor: torch.Tensor, src: int, stream=None):
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
        self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), src,
                           self.comm, cudaStream_t(stream.cuda_stream))

    @contextmanager
    def change_state(self,
                     enable: Optional[bool] = None,
                     stream: Optional[torch.cuda.Stream] = None):
        """
        A context manager to change the state of the communicator.
        If `enable` is None, it falls back to `self.available`; if `stream`
        is None, the current stream is kept. The previous state is restored
        when the context exits.
        """
        if enable is None:
            # guess a default value when not specified
            enable = self.available
        if stream is None:
            stream = self.stream

        old_disable = self.disabled
        old_stream = self.stream

        self.stream = stream
        self.disabled = not enable
        yield

        self.disabled = old_disable
        self.stream = old_stream
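

# -----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module above). Assumptions:
# the script is launched with at least two processes (e.g. via
# `torchrun --nproc-per-node=2`), one visible GPU per rank, and a CPU-side
# "gloo" process group, since PyNcclCommunicator asserts the group is
# non-NCCL so the unique id can be broadcast over CPU.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    import os

    dist.init_process_group(backend="gloo")
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    comm = PyNcclCommunicator(group=dist.group.WORLD,
                              device=f"cuda:{local_rank}")

    payload = torch.ones(4, device=f"cuda:{local_rank}")
    # The communicator is disabled by default (see the end of __init__), so
    # collectives must be wrapped in `change_state(enable=True)`.
    with comm.change_state(enable=True):
        comm.all_reduce(payload)
        comm.stream.synchronize()
    # each element should now equal the number of ranks
    print(f"rank {comm.rank}: {payload.tolist()}")

    dist.destroy_process_group()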