@@ -1,15 +1,14 @@
 from collections import namedtuple
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch.distributed import ProcessGroup
 
-from .parallel_state import (
-    get_tensor_model_parallel_group,
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-    is_pynccl_enabled_for_all_reduce,
-)
+from .parallel_state import (get_cpu_world_group,
+                             get_tensor_model_parallel_group,
+                             get_tensor_model_parallel_rank,
+                             get_tensor_model_parallel_world_size,
+                             is_pynccl_enabled_for_all_reduce)
 
 
 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@@ -24,8 +23,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     value as the output.
     """
     from aphrodite.distributed.device_communicators import pynccl_utils
-    from aphrodite.distributed.device_communicators.custom_all_reduce import (
-        custom_all_reduce)
+    from aphrodite.distributed.device_communicators.custom_all_reduce import \
+        custom_all_reduce
+
     # Bypass the function if we are using only 1 GPU.
     if get_tensor_model_parallel_world_size() == 1:
         return input_
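The context line kept above ends the all-reduce docstring ("... value as the output."), i.e. the reduction may or may not happen in place depending on which backend path runs. A minimal call-site sketch under that contract (the variable name is hypothetical, not from this diff):

    # Always rebind to the return value; do not rely on in-place mutation
    # of the input tensor.
    hidden_states = tensor_model_parallel_all_reduce(hidden_states)
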
@@ -140,13 +140,46 @@ def broadcast_object_list(obj_list: List[Any],
 TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
 
 
+def _split_tensor_dict(
+    tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
+) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
+    """Split the tensor dictionary into two parts:
+    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
+         by its metadata.
+    2. A list of tensors.
+    """
+    metadata_list = []
+    tensor_list = []
+    for key, value in tensor_dict.items():
+        if isinstance(value, torch.Tensor):
+            # Note(youkaichao): currently this only supports broadcasting
+            # tensors on cuda. In the future, we can add device as a field in
+            # TensorMetadata to support broadcasting tensors on different
+            # devices.
+            assert value.is_cuda, (
+                f"Tensor {key}: {value} is not on cuda. Currently we only "
+                f"support broadcasting tensors on cuda.")
+            metadata_list.append((key, TensorMetadata(value.dtype,
+                                                      value.size())))
+            tensor_list.append(value)
+        else:
+            metadata_list.append((key, value))
+    return metadata_list, tensor_list
+
+
 def broadcast_tensor_dict(
     tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
     src: int = 0,
     group: Optional[ProcessGroup] = None,
-) -> Dict[Any, Union[torch.Tensor, Any]]:
-    """Broadcast the input tensor dictionary."""
+    metadata_group: Optional[ProcessGroup] = None
+) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    """Broadcast the input tensor dictionary.
+    `group` is used to broadcast the tensors, while `metadata_group` is used
+    to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
+    dtypes).
+    """
     group = group or torch.distributed.group.WORLD
+    metadata_group = metadata_group or get_cpu_world_group()
     ranks = torch.distributed.get_process_group_ranks(group)
     assert src in ranks, f"Invalid src rank ({src})"
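To make the new helper's contract concrete, here is a standalone sketch (not part of the diff) that mirrors _split_tensor_dict on a toy dict. It drops the CUDA-only assert so it runs anywhere, and the sample keys and shapes are illustrative:

    from collections import namedtuple

    import torch

    TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])

    def split_tensor_dict_sketch(tensor_dict):
        # Tensors are replaced by (dtype, size) metadata; the payloads go
        # into a separate list that is broadcast tensor-by-tensor later.
        metadata_list, tensor_list = [], []
        for key, value in tensor_dict.items():
            if isinstance(value, torch.Tensor):
                metadata_list.append((key, TensorMetadata(value.dtype,
                                                          value.size())))
                tensor_list.append(value)
            else:
                metadata_list.append((key, value))
        return metadata_list, tensor_list

    meta, tensors = split_tensor_dict_sketch({
        "hidden": torch.zeros(2, 8, dtype=torch.float16),
        "step": 3,
    })
    # meta    -> [("hidden", TensorMetadata(torch.float16, torch.Size([2, 8]))),
    #             ("step", 3)]
    # tensors -> [the float16 tensor]

Only the small metadata_list goes through pickling in broadcast_object_list; the tensors themselves keep their device placement and are sent separately, which is what the next hunk wires up.
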
@@ -154,45 +187,38 @@ def broadcast_tensor_dict(
     world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return tensor_dict
-
     rank = torch.distributed.get_rank()
     if rank == src:
+        metadata_list: List[Tuple[Any, Any]] = []
         assert isinstance(
             tensor_dict,
             dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
-        metadata_list = []
-        for key, value in tensor_dict.items():
-            if isinstance(value, torch.Tensor):
-                assert value.is_cuda, (
-                    f"Tensor {key}: {value} is not on cuda. Currently we only "
-                    f"support broadcasting tensors on cuda.")
-                metadata_list.append(
-                    (key, TensorMetadata(value.dtype, value.size())))
-            else:
-                metadata_list.append((key, value))
+        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
+        # `metadata_list` lives in CPU memory.
+        # `broadcast_object_list` involves serialization and deserialization,
+        # all happening on CPU. Therefore, we can use the CPU group.
         torch.distributed.broadcast_object_list([metadata_list],
                                                 src=src,
-                                                group=group)
+                                                group=metadata_group)
         async_handles = []
-        for key, value in metadata_list:
-            if isinstance(value, TensorMetadata):
-                tensor = tensor_dict[key]
-                async_handles.append(
-                    torch.distributed.broadcast(tensor,
-                                                src=src,
-                                                group=group,
-                                                async_op=True))
+        for tensor in tensor_list:
+            async_handles.append(
+                torch.distributed.broadcast(tensor,
+                                            src=src,
+                                            group=group,
+                                            async_op=True))
         for async_handle in async_handles:
             async_handle.wait()
+
     else:
         recv_metadata_list = [None]
         torch.distributed.broadcast_object_list(recv_metadata_list,
                                                 src=src,
-                                                group=group)
-        metadata_list = recv_metadata_list[0]
+                                                group=metadata_group)
+        assert recv_metadata_list[0] is not None
         tensor_dict = {}
         async_handles = []
-        for key, value in metadata_list:  # pylint: disable=not-an-iterable
+        for key, value in recv_metadata_list[0]:
             if isinstance(value, TensorMetadata):
                 tensor = torch.empty(value.size,
                                      dtype=value.dtype,
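For context on the metadata_group design above: broadcast_object_list pickles and unpickles on the CPU, so routing it through a CPU-backed (gloo) group, which is presumably what get_cpu_world_group() returns, keeps that traffic off the group that carries the tensors. A self-contained sketch of the same split using only plain torch.distributed calls; the two-rank launch, file name, keys, and shapes are illustrative assumptions, not taken from this diff:

    # Launch with: torchrun --nproc_per_node=2 metadata_group_sketch.py
    import torch
    import torch.distributed as dist

    def main():
        dist.init_process_group(backend="nccl")
        rank = dist.get_rank()
        torch.cuda.set_device(rank)

        # Same ranks, second group on gloo: pickled metadata stays on CPU,
        # tensor payloads travel over the default NCCL group on GPU.
        cpu_group = dist.new_group(backend="gloo")

        if rank == 0:
            meta = [("hidden", (torch.float16, (2, 8))), ("step", 3)]
            payload = torch.ones(2, 8, dtype=torch.float16, device="cuda")
            dist.broadcast_object_list([meta], src=0, group=cpu_group)
            dist.broadcast(payload, src=0)
        else:
            recv = [None]
            dist.broadcast_object_list(recv, src=0, group=cpu_group)
            # Second entry is plain metadata ("step", 3); only the first
            # describes a tensor to receive.
            (_, (dtype, size)), _ = recv[0]
            payload = torch.empty(size, dtype=dtype, device="cuda")
            dist.broadcast(payload, src=0)

        dist.destroy_process_group()

    if __name__ == "__main__":
        main()

In the diff itself the CPU group comes from get_cpu_world_group(), and the tensor broadcasts are issued with async_op=True and waited on together, but the group split shown here is the core of the change.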