# communication_op.py

from collections import namedtuple
from typing import Any, Dict, List, Optional, Union

from torch.distributed import ProcessGroup
import torch

from aphrodite.modeling.megatron import cupy_utils
from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_group,
    is_cupy_nccl_enabled_for_all_reduce,
)
from aphrodite.modeling.megatron.custom_all_reduce import custom_all_reduce


def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across model parallel group.

    NOTE: This operation will be applied in-place on the input tensor if
    disable_custom_all_reduce is set to True. Otherwise, this operation may or
    may not be applied in place depending on whether custom all reduce is
    invoked for a particular tensor, which further depends on the tensor size
    and GPU topology.

    TLDR: always assume this function modifies its input, but use the return
    value as the output.
    """
    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size() == 1:
        return input_
    out = custom_all_reduce(input_)
    if out is not None:
        return out
    if is_cupy_nccl_enabled_for_all_reduce():
        # TODO: support multiple parallel groups.
        cupy_utils.all_reduce(input_)
    else:
        torch.distributed.all_reduce(input_,
                                     group=get_tensor_model_parallel_group())
    return input_
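

# Usage sketch (hypothetical helper, not a public API): a row-parallel linear
# layer computes a partial matmul against each rank's weight shard and sums
# the partials across the tensor-parallel group. Assumes torch.distributed
# and the tensor-parallel groups are already initialized.
def _example_row_parallel_matmul(x: torch.Tensor,
                                 weight_shard: torch.Tensor) -> torch.Tensor:
    # Each rank multiplies by its shard of the weight matrix.
    partial = torch.matmul(x, weight_shard)
    # Always use the return value; the all-reduce may or may not be in-place.
    return tensor_model_parallel_all_reduce(partial)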


def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                     dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor across model parallel group."""
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    input_size = input_.size()
    # Allocate output tensor.
    output_tensor = torch.empty((world_size, ) + input_size,
                                dtype=input_.dtype,
                                device=input_.device)
    # All-gather.
    torch.distributed.all_gather_into_tensor(
        output_tensor, input_, group=get_tensor_model_parallel_group())
    # Reshape.
    output_tensor = output_tensor.movedim(0, dim)
    output_tensor = output_tensor.reshape(input_size[:dim] +
                                          (world_size * input_size[dim], ) +
                                          input_size[dim + 1:])
    return output_tensor
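

# Usage sketch (hypothetical helper): column-parallel layers produce outputs
# sharded along the last dimension; the all-gather stitches the shards back
# into the full hidden dimension. Assumes the tensor-parallel groups are
# already initialized.
def _example_gather_column_parallel_output(
        partial_output: torch.Tensor) -> torch.Tensor:
    # [batch, seq, hidden // world_size] -> [batch, seq, hidden]
    return tensor_model_parallel_all_gather(partial_output, dim=-1)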


def tensor_model_parallel_gather(input_: torch.Tensor,
                                 dst: int = 0,
                                 dim: int = -1) -> torch.Tensor:
    """Gather the input tensor across model parallel group.

    NOTE: We assume that the input tensor is on the same device across
    all the ranks.
    """
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    # Allocate the gather list on the destination rank only.
    if get_tensor_model_parallel_rank() == dst:
        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
    else:
        gather_list = None
    # Gather.
    torch.distributed.gather(input_,
                             gather_list,
                             dst=dst,
                             group=get_tensor_model_parallel_group())
    if get_tensor_model_parallel_rank() == dst:
        output_tensor = torch.cat(gather_list, dim=dim)
    else:
        output_tensor = None
    return output_tensor
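

# Usage sketch (hypothetical helper): gather sharded logits onto the rank that
# runs sampling. Only `dst` receives the concatenated tensor; every other rank
# gets None, so callers must branch on the result.
def _example_gather_logits_for_sampling(
        logits_shard: torch.Tensor) -> Optional[torch.Tensor]:
    gathered = tensor_model_parallel_gather(logits_shard, dst=0, dim=-1)
    if gathered is None:
        # Non-destination ranks skip sampling.
        return None
    return gathered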


def broadcast(input_: torch.Tensor,
              src: int = 0,
              group: Optional[ProcessGroup] = None):
    """Broadcast the input tensor."""
    group = group or torch.distributed.group.WORLD
    ranks = torch.distributed.get_process_group_ranks(group)
    assert src in ranks, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    world_size = torch.distributed.get_world_size(group=group)
    if world_size == 1:
        return input_
    # Broadcast.
    torch.distributed.broadcast(input_, src=src, group=group)
    return input_
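

# Usage sketch (hypothetical helper): every rank must provide a buffer of the
# same shape and dtype; only the source rank's values survive the call.
# Assumes a CUDA device and an initialized default process group.
def _example_broadcast_positions() -> torch.Tensor:
    if torch.distributed.get_rank() == 0:
        positions = torch.arange(8, device="cuda")
    else:
        positions = torch.empty(8, dtype=torch.long, device="cuda")
    return broadcast(positions, src=0)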


def broadcast_object_list(obj_list: List[Any],
                          src: int = 0,
                          group: Optional[ProcessGroup] = None):
    """Broadcast the input object list."""
    group = group or torch.distributed.group.WORLD
    ranks = torch.distributed.get_process_group_ranks(group)
    assert src in ranks, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    world_size = torch.distributed.get_world_size(group=group)
    if world_size == 1:
        return obj_list
    # Broadcast.
    torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
    return obj_list
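

# Usage sketch (hypothetical helper): broadcast picklable Python metadata from
# the source rank. Non-source ranks pass a same-length placeholder list that
# is overwritten in place.
def _example_broadcast_sampling_metadata() -> List[Any]:
    if torch.distributed.get_rank() == 0:
        obj_list: List[Any] = [{"temperature": 0.7, "top_p": 0.9}]
    else:
        obj_list = [None]
    return broadcast_object_list(obj_list, src=0)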


TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])


def broadcast_tensor_dict(
    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
    src: int = 0,
    group: Optional[ProcessGroup] = None,
) -> Dict[Any, Union[torch.Tensor, Any]]:
    """Broadcast the input tensor dictionary."""
    group = group or torch.distributed.group.WORLD
    ranks = torch.distributed.get_process_group_ranks(group)
    assert src in ranks, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    world_size = torch.distributed.get_world_size(group=group)
    if world_size == 1:
        return tensor_dict

    rank = torch.distributed.get_rank()
    if rank == src:
        assert isinstance(
            tensor_dict,
            dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
        metadata_list = []
        for key, value in tensor_dict.items():
            if isinstance(value, torch.Tensor):
                assert value.is_cuda, (
                    f"Tensor {key}: {value} is not on cuda. Currently we only "
                    f"support broadcasting tensors on cuda.")
                metadata_list.append(
                    (key, TensorMetadata(value.dtype, value.size())))
            else:
                metadata_list.append((key, value))
        torch.distributed.broadcast_object_list([metadata_list],
                                                src=src,
                                                group=group)
        for key, value in metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = tensor_dict[key]
                # Broadcast the tensor payload on the same group that the
                # receiving ranks use for their matching broadcast.
                torch.distributed.broadcast(tensor, src=src, group=group)
    else:
        recv_metadata_list = [None]
        torch.distributed.broadcast_object_list(recv_metadata_list,
                                                src=src,
                                                group=group)
        metadata_list = recv_metadata_list[0]
        tensor_dict = {}
        async_handles = []
        for key, value in metadata_list:  # pylint: disable=not-an-iterable
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
                                     device="cuda")
                async_handle = torch.distributed.broadcast(tensor,
                                                           src=src,
                                                           async_op=True,
                                                           group=group)
                async_handles.append(async_handle)
                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        for async_handle in async_handles:
            async_handle.wait()
    return tensor_dict
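

# Usage sketch (hypothetical helper with illustrative keys): the source rank
# passes the populated dictionary; other ranks pass nothing and receive a
# reconstructed dict with freshly allocated CUDA tensors plus the non-tensor
# values. Assumes a CUDA device and an initialized default process group.
def _example_exchange_model_inputs(src: int = 0) -> Dict[Any, Any]:
    if torch.distributed.get_rank() == src:
        inputs = {
            "input_ids": torch.ones(4, 16, dtype=torch.long, device="cuda"),
            "is_prompt": True,
        }
        return broadcast_tensor_dict(inputs, src=src)
    # Non-source ranks receive the same structure back.
    return broadcast_tensor_dict(src=src)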