pynccl.py

from contextlib import contextmanager
from typing import Optional, Union

# ===================== import region =====================
import torch
import torch.distributed as dist
from loguru import logger
from torch.distributed import ProcessGroup, ReduceOp

from aphrodite.distributed.device_communicators.pynccl_wrapper import (
    NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
    ncclRedOpTypeEnum, ncclUniqueId)
from aphrodite.distributed.parallel_state import (get_cpu_world_group,
                                                  get_local_rank)


class PyNcclCommunicator:

    def __init__(
        self,
        group: Optional[ProcessGroup] = None,
        device: Optional[Union[int, str, torch.device]] = None,
        library_path: Optional[str] = None,
    ):
        """
        Args:
            group: the process group to work on. If None, it will use the
                default process group.
            device: the device to bind the PyNcclCommunicator to. If None,
                it will be bound to f"cuda:{local_rank}".
            library_path: the path to the NCCL library. If None, it will
                use the default library path.

        It is the caller's responsibility to make sure each communicator
        is bound to a unique device.
        """
        assert dist.is_initialized()
        group = get_cpu_world_group() if group is None else group
        assert dist.get_backend(group) != dist.Backend.NCCL, (
            "PyNcclCommunicator should be attached to a non-NCCL group.")
        self.group = group
        # note: this rank is the rank in the group
        self.rank = dist.get_rank(group)
        self.world_size = dist.get_world_size(group)

        # if world_size == 1, no need to create communicator
        if self.world_size == 1:
            self.available = False
            self.disabled = True
            self.stream = None
            return
        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # disable because of missing NCCL library
            # e.g. in a non-GPU environment
            self.available = False
            self.disabled = True
            self.stream = None
            return

        self.available = True
        self.disabled = False

        logger.info(f"Aphrodite is using nccl=={self.nccl.ncclGetVersion()}")

        if self.rank == 0:
            # get the unique id from NCCL
            self.unique_id = self.nccl.ncclGetUniqueId()
        else:
            # construct an empty unique id
            self.unique_id = ncclUniqueId()
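        # The unique id generated on rank 0 has to reach every other rank
        # before ncclCommInitRank can be called, so it is serialized into a
        # byte tensor and broadcast over the CPU (non-NCCL) process group.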
        tensor = torch.ByteTensor(list(self.unique_id.internal))
        ranks = dist.get_process_group_ranks(group)
        # arg `src` in `broadcast` is the global rank
        dist.broadcast(tensor, src=ranks[0], group=group)
        byte_list = tensor.tolist()
        for i, byte in enumerate(byte_list):
            self.unique_id.internal[i] = byte
        if device is None:
            local_rank = get_local_rank()
            device = torch.device(f"cuda:{local_rank}")
        elif isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        # now `device` is a `torch.device` object
        assert isinstance(device, torch.device)
        self.device = device
        # nccl communicator and stream will use this device
        # `torch.cuda.device` is a context manager that changes the
        # current cuda device to the specified one
        with torch.cuda.device(device):
            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                self.world_size, self.unique_id, self.rank)
            self.stream = torch.cuda.Stream()
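
            # The warmup below lets NCCL and CUDA finish their lazy
            # initialization here, so the first real all_reduce does not pay
            # that setup cost.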
            # A small all_reduce for warmup.
            data = torch.zeros(1, device=device)
            self.all_reduce(data)
            self.stream.synchronize()
            del data

        # By default the communicator is disabled, e.g. when profiling models
        # and during the prefill phase. To use it, wrap calls in
        # `with obj.change_state(enable=True)`, usually when running with
        # CUDA graphs.
        self.disabled = True

    def all_reduce(self,
                   tensor: torch.Tensor,
                   op: ReduceOp = ReduceOp.SUM,
                   stream=None):
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
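        # The same pointer is passed as both send and receive buffer, so
        # NCCL performs the reduction in place on `tensor`.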
        self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
                                buffer_type(tensor.data_ptr()), tensor.numel(),
                                ncclDataTypeEnum.from_torch(tensor.dtype),
                                ncclRedOpTypeEnum.from_torch(op), self.comm,
                                cudaStream_t(stream.cuda_stream))

    @contextmanager
    def change_state(self,
                     enable: Optional[bool] = None,
                     stream: Optional[torch.cuda.Stream] = None):
        """
        A context manager to temporarily change the state of the communicator
        (whether it is enabled, and which stream it uses). The previous state
        is restored when the context exits.
        """
        if enable is None:
            # guess a default value when not specified
            enable = self.available

        if stream is None:
            stream = self.stream

        old_disable = self.disabled
        old_stream = self.stream

        self.stream = stream
        self.disabled = not enable

        yield

        self.disabled = old_disable
        self.stream = old_stream
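

if __name__ == "__main__":
    # Minimal usage sketch: assumes at least two CUDA devices, a working NCCL
    # installation, and a launch via `torchrun --nproc_per_node=2 pynccl.py`.
    # The gloo-backed world group plays the role of the non-NCCL group that
    # PyNcclCommunicator expects.
    import os

    dist.init_process_group(backend="gloo")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    comm = PyNcclCommunicator(group=dist.group.WORLD, device=local_rank)
    x = torch.ones(4, device=comm.device)
    torch.cuda.synchronize()  # make sure x is materialized before reducing

    # The communicator starts out disabled; enable it just for this block.
    with comm.change_state(enable=True):
        comm.all_reduce(x)  # in-place sum across all ranks
    comm.stream.synchronize()

    # every element of x now equals the world size
    print(f"rank {comm.rank}: {x.tolist()}")
    dist.destroy_process_group()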