123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210 |
- from collections import namedtuple
- from typing import Any, Dict, List, Optional, Union
- import torch
- from torch.distributed import ProcessGroup
- from .parallel_state import (
- get_tensor_model_parallel_group,
- get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size,
- is_pynccl_enabled_for_all_reduce,
- )
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce ``input_`` across the tensor-model-parallel group.

    Backends are tried in this order:
      1. the custom all-reduce kernel, whose result is returned as-is when it
         is not ``None`` (it may be a different tensor from ``input_``);
      2. pynccl, if it is enabled for all-reduce;
      3. ``torch.distributed.all_reduce`` on the tensor-model-parallel group.
    Paths 2 and 3 reduce ``input_`` in place and return it.

    Whether the operation happens in place therefore depends on whether the
    custom kernel handles this particular tensor. Callers should assume
    ``input_`` may be modified and always use the returned tensor.
    """
    from aphrodite.distributed.device_communicators import pynccl_utils
    from aphrodite.distributed.device_communicators.custom_all_reduce import (
        custom_all_reduce)

    # A single tensor-parallel rank has nothing to reduce with.
    if get_tensor_model_parallel_world_size() == 1:
        return input_

    # The custom kernel signals "not handled" by returning None.
    reduced = custom_all_reduce(input_)
    if reduced is not None:
        return reduced

    if not is_pynccl_enabled_for_all_reduce():
        torch.distributed.all_reduce(input_,
                                     group=get_tensor_model_parallel_group())
    else:
        # TODO: support multiple parallel groups.
        pynccl_utils.all_reduce(input_)
    return input_
def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                     dim: int = -1) -> torch.Tensor:
    """All-gather ``input_`` across the tensor-model-parallel group.

    Each rank's tensor is concatenated along ``dim`` in rank order. The
    result is ``world_size`` times larger than ``input_`` along ``dim`` and
    matches it on every other axis. With a single rank, ``input_`` is
    returned unchanged. ``dim`` may be negative.
    """
    tp_size = get_tensor_model_parallel_world_size()
    if tp_size == 1:
        return input_

    ndim = input_.dim()
    assert -ndim <= dim < ndim, (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        dim += ndim  # normalise to a non-negative axis index

    shape = input_.size()
    # Stacked buffer with a leading rank axis: gathered[r] receives rank r's
    # tensor.
    gathered = torch.empty((tp_size, ) + shape,
                           dtype=input_.dtype,
                           device=input_.device)
    torch.distributed.all_gather_into_tensor(
        gathered, input_, group=get_tensor_model_parallel_group())

    # Move the rank axis just before `dim`, then merge the two axes so the
    # per-rank chunks end up laid out back-to-back along `dim`.
    merged_shape = shape[:dim] + (tp_size * shape[dim], ) + shape[dim + 1:]
    return gathered.movedim(0, dim).reshape(merged_shape)
def tensor_model_parallel_gather(input_: torch.Tensor,
                                 dst: int = 0,
                                 dim: int = -1) -> Optional[torch.Tensor]:
    """Gather ``input_`` from every tensor-model-parallel rank onto ``dst``.

    NOTE: We assume that the input tensor is on the same device across
    all the ranks.

    Args:
        input_: This rank's contribution; every rank must pass a tensor of
            the same shape and dtype.
        dst: Destination rank, counted *within* the tensor-model-parallel
            group (it is compared with ``get_tensor_model_parallel_rank()``).
        dim: Dimension along which the per-rank tensors are concatenated.
            Negative values count from the end.

    Returns:
        On the destination rank, the per-rank tensors concatenated along
        ``dim`` in rank order. On all other ranks, ``None``. With a single
        rank, ``input_`` itself.
    """
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    group = get_tensor_model_parallel_group()
    is_dst = get_tensor_model_parallel_rank() == dst
    # `dst` is a group-local rank, but torch.distributed.gather's `dst` is a
    # global rank. Translate it so the call also targets the right process
    # when the tensor-parallel group is not the whole world (e.g. with
    # pipeline parallelism). When the group *is* the world, this is the
    # identity.
    global_dst = torch.distributed.get_process_group_ranks(group)[dst]

    # Only the destination rank needs receive buffers.
    gather_list = ([torch.empty_like(input_) for _ in range(world_size)]
                   if is_dst else None)
    torch.distributed.gather(input_,
                             gather_list,
                             dst=global_dst,
                             group=group)
    if not is_dst:
        return None
    return torch.cat(gather_list, dim=dim)
def broadcast(input_: torch.Tensor,
              src: int = 0,
              group: Optional[ProcessGroup] = None):
    """Broadcast ``input_`` in place from rank ``src`` to every rank of ``group``.

    ``group`` defaults to the world group. ``src`` must be one of the global
    ranks belonging to ``group``. When the group has a single member, no
    communication happens. Returns ``input_``.
    """
    pg = group if group else torch.distributed.group.WORLD
    members = torch.distributed.get_process_group_ranks(pg)
    assert src in members, f"Invalid src rank ({src})"
    # Only communicate when there is more than one GPU in the group.
    if torch.distributed.get_world_size(group=pg) != 1:
        torch.distributed.broadcast(input_, src=src, group=pg)
    return input_
def broadcast_object_list(obj_list: List[Any],
                          src: int = 0,
                          group: Optional[ProcessGroup] = None):
    """Broadcast the picklable objects in ``obj_list`` from rank ``src``.

    ``obj_list`` is updated in place on receiving ranks and then returned.
    ``group`` defaults to the world group, and ``src`` must be one of its
    global ranks. Single-member groups skip communication.
    """
    pg = group if group else torch.distributed.group.WORLD
    members = torch.distributed.get_process_group_ranks(pg)
    assert src in members, f"Invalid src rank ({src})"
    # Only communicate when there is more than one GPU in the group.
    if torch.distributed.get_world_size(group=pg) != 1:
        torch.distributed.broadcast_object_list(obj_list,
                                                src=src,
                                                group=pg)
    return obj_list
# Lightweight, picklable description of a tensor (its dtype and size).
# broadcast_tensor_dict sends these in place of tensors in its metadata
# broadcast so that receivers can allocate matching buffers.
TensorMetadata = namedtuple("TensorMetadata", "dtype size")
def _send_tensor_dict(
    tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
    src: int,
    group: ProcessGroup,
) -> None:
    """Source side of broadcast_tensor_dict: send metadata, then tensor data."""
    assert isinstance(
        tensor_dict,
        dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
    # Tensors are replaced by TensorMetadata; other values travel verbatim.
    metadata_list = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            assert value.is_cuda, (
                f"Tensor {key}: {value} is not on cuda. Currently we only "
                f"support broadcasting tensors on cuda.")
            metadata_list.append(
                (key, TensorMetadata(value.dtype, value.size())))
        else:
            metadata_list.append((key, value))
    torch.distributed.broadcast_object_list([metadata_list],
                                            src=src,
                                            group=group)
    async_handles = []
    # Hold every send buffer (possibly a contiguous copy) until all async
    # broadcasts have been waited on.
    send_buffers = []
    for key, value in metadata_list:
        if not isinstance(value, TensorMetadata):
            continue
        tensor = tensor_dict[key]
        if tensor.numel() == 0:
            # Nothing to transfer. Receivers make the same decision from
            # the same metadata, so the collectives stay matched.
            continue
        # Collective backends expect dense contiguous buffers. contiguous()
        # returns the tensor itself when it already is contiguous.
        tensor = tensor.contiguous()
        send_buffers.append(tensor)
        async_handles.append(
            torch.distributed.broadcast(tensor,
                                        src=src,
                                        group=group,
                                        async_op=True))
    for async_handle in async_handles:
        async_handle.wait()


def _recv_tensor_dict(
    src: int,
    group: ProcessGroup,
) -> Dict[Any, Union[torch.Tensor, Any]]:
    """Receiver side of broadcast_tensor_dict: rebuild the dict from ``src``."""
    recv_metadata_list = [None]
    # NOTE: broadcast_object_list unpickles what the source rank sent; peers
    # in the process group are trusted.
    torch.distributed.broadcast_object_list(recv_metadata_list,
                                            src=src,
                                            group=group)
    metadata_list = recv_metadata_list[0]
    tensor_dict = {}
    async_handles = []
    for key, value in metadata_list:  # pylint: disable=not-an-iterable
        if not isinstance(value, TensorMetadata):
            tensor_dict[key] = value
            continue
        # Allocate on the current CUDA device, then receive into it.
        tensor = torch.empty(value.size, dtype=value.dtype, device="cuda")
        tensor_dict[key] = tensor
        if tensor.numel() == 0:
            # The sender skips empty tensors; keep the collectives matched.
            continue
        async_handles.append(
            torch.distributed.broadcast(tensor,
                                        src=src,
                                        async_op=True,
                                        group=group))
    for async_handle in async_handles:
        async_handle.wait()
    return tensor_dict


def broadcast_tensor_dict(
    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
    src: int = 0,
    group: Optional[ProcessGroup] = None,
) -> Dict[Any, Union[torch.Tensor, Any]]:
    """Broadcast a dictionary of CUDA tensors and picklable values.

    The source rank first broadcasts a metadata list with one entry per key,
    in dict order. Tensors appear in it as ``TensorMetadata(dtype, size)``
    and other values appear verbatim. The tensor data then follows via async
    broadcasts. Receivers allocate CUDA buffers from the metadata and fill
    them.

    Args:
        tensor_dict: The dict to send. Only read on the ``src`` rank, where it
            must be a ``dict`` whose tensor values are on CUDA.
        src: Global rank of the sender; it must belong to ``group``.
        group: Process group to broadcast over. Defaults to the world group.

    Returns:
        On ``src``, the ``tensor_dict`` that was passed in. On other ranks, a
        newly built dict with the same keys. If the group has a single
        member, ``tensor_dict`` is returned unchanged, which may be ``None``.
    """
    group = group or torch.distributed.group.WORLD
    ranks = torch.distributed.get_process_group_ranks(group)
    assert src in ranks, f"Invalid src rank ({src})"
    # Bypass the function if we are using only 1 GPU.
    world_size = torch.distributed.get_world_size(group=group)
    if world_size == 1:
        return tensor_dict
    # `src` is a global rank, so compare it against the global rank.
    if torch.distributed.get_rank() == src:
        _send_tensor_dict(tensor_dict, src, group)
        return tensor_dict
    return _recv_tensor_dict(src, group)
|