broadcast metadata through cpu

AlpinDale 8 months ago
parent commit ac5b4b6aa7
+ 60 - 34
aphrodite/distributed/communication_op.py

@@ -1,15 +1,14 @@
 from collections import namedtuple
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch.distributed import ProcessGroup
 
-from .parallel_state import (
-    get_tensor_model_parallel_group,
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-    is_pynccl_enabled_for_all_reduce,
-)
+from .parallel_state import (get_cpu_world_group,
+                             get_tensor_model_parallel_group,
+                             get_tensor_model_parallel_rank,
+                             get_tensor_model_parallel_world_size,
+                             is_pynccl_enabled_for_all_reduce)
 
 
 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@@ -24,8 +23,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     value as the output.
     """
     from aphrodite.distributed.device_communicators import pynccl_utils
-    from aphrodite.distributed.device_communicators.custom_all_reduce import (
-        custom_all_reduce)
+    from aphrodite.distributed.device_communicators.custom_all_reduce import \
+        custom_all_reduce
+
     # Bypass the function if we are using only 1 GPU.
     if get_tensor_model_parallel_world_size() == 1:
         return input_
@@ -140,13 +140,46 @@ def broadcast_object_list(obj_list: List[Any],
 TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
 
 
+def _split_tensor_dict(
+    tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
+) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
+    """Split the tensor dictionary into two parts:
+    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
+         by its metadata.
+    2. A list of tensors.
+    """
+    metadata_list = []
+    tensor_list = []
+    for key, value in tensor_dict.items():
+        if isinstance(value, torch.Tensor):
+            # Note(youkaichao): currently this only supports broadcasting
+            # tensors on cuda. In the future, we can add device as a field in
+            # TensorMetadata to support broadcasting tensors on different
+            # devices.
+            assert value.is_cuda, (
+                f"Tensor {key}: {value} is not on cuda. Currently we only "
+                f"support broadcasting tensors on cuda.")
+            metadata_list.append((key, TensorMetadata(value.dtype,
+                                                      value.size())))
+            tensor_list.append(value)
+        else:
+            metadata_list.append((key, value))
+    return metadata_list, tensor_list
+
+
 def broadcast_tensor_dict(
     tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
     src: int = 0,
     group: Optional[ProcessGroup] = None,
-) -> Dict[Any, Union[torch.Tensor, Any]]:
-    """Broadcast the input tensor dictionary."""
+    metadata_group: Optional[ProcessGroup] = None
+) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    """Broadcast the input tensor dictionary.
+    `group` is used to broadcast the tensors, while `metadata_group` is used
+     to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
+     dtypes).
+    """
     group = group or torch.distributed.group.WORLD
+    metadata_group = metadata_group or get_cpu_world_group()
     ranks = torch.distributed.get_process_group_ranks(group)
     assert src in ranks, f"Invalid src rank ({src})"
 
@@ -154,45 +187,38 @@ def broadcast_tensor_dict(
     world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return tensor_dict
-
     rank = torch.distributed.get_rank()
     if rank == src:
+        metadata_list: List[Tuple[Any, Any]] = []
         assert isinstance(
             tensor_dict,
             dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
-        metadata_list = []
-        for key, value in tensor_dict.items():
-            if isinstance(value, torch.Tensor):
-                assert value.is_cuda, (
-                    f"Tensor {key}: {value} is not on cuda. Currently we only "
-                    f"support broadcasting tensors on cuda.")
-                metadata_list.append(
-                    (key, TensorMetadata(value.dtype, value.size())))
-            else:
-                metadata_list.append((key, value))
+        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
+        # `metadata_list` lives in CPU memory.
+        # `broadcast_object_list` involves serialization and deserialization,
+        # all happening on CPU. Therefore, we can use the CPU group.
         torch.distributed.broadcast_object_list([metadata_list],
                                                 src=src,
-                                                group=group)
+                                                group=metadata_group)
         async_handles = []
-        for key, value in metadata_list:
-            if isinstance(value, TensorMetadata):
-                tensor = tensor_dict[key]
-                async_handles.append(
-                    torch.distributed.broadcast(tensor,
-                                                src=src,
-                                                group=group,
-                                                async_op=True))
+        for tensor in tensor_list:
+            async_handles.append(
+                torch.distributed.broadcast(tensor,
+                                            src=src,
+                                            group=group,
+                                            async_op=True))
         for async_handle in async_handles:
             async_handle.wait()
+
     else:
         recv_metadata_list = [None]
         torch.distributed.broadcast_object_list(recv_metadata_list,
                                                 src=src,
-                                                group=group)
-        metadata_list = recv_metadata_list[0]
+                                                group=metadata_group)
+        assert recv_metadata_list[0] is not None
         tensor_dict = {}
         async_handles = []
-        for key, value in metadata_list:  # pylint: disable=not-an-iterable
+        for key, value in recv_metadata_list[0]:
             if isinstance(value, TensorMetadata):
                 tensor = torch.empty(value.size,
                                      dtype=value.dtype,

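The new `metadata_group` parameter above routes the pickled dict structure over a CPU process group, while the tensor payloads still travel over the device group. Below is a minimal usage sketch, assuming two GPUs launched via `torchrun --nproc_per_node=2`; the hand-built gloo group stands in for `get_cpu_world_group()` and is illustrative only, not Aphrodite's actual group setup:

    # demo_broadcast.py -- run with: torchrun --nproc_per_node=2 demo_broadcast.py
    # Illustrative sketch: the gloo group below plays the role of
    # get_cpu_world_group(); Aphrodite builds its groups in parallel_state.
    import torch
    import torch.distributed as dist

    from aphrodite.distributed.communication_op import broadcast_tensor_dict

    dist.init_process_group(backend="nccl")     # device group: tensor traffic
    cpu_group = dist.new_group(backend="gloo")  # CPU group: metadata traffic
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    if rank == 0:
        # The source rank supplies the dict; tensor values must be on CUDA.
        payload = {"hidden": torch.randn(4, 8, device="cuda"), "step": 7}
    else:
        # Non-source ranks pass nothing and receive the reconstructed dict.
        payload = None

    result = broadcast_tensor_dict(payload, src=0, metadata_group=cpu_group)
    assert result["step"] == 7 and result["hidden"].is_cuda

Since `broadcast_object_list` pickles and unpickles the metadata on the CPU regardless of the group backend, sending it over gloo skips the detour through GPU memory; only the tensors gathered into `tensor_list` go over the NCCL group, as asynchronous broadcasts that are waited on together.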
+ 1 - 0
aphrodite/quantization/gptq_marlin.py

@@ -25,6 +25,7 @@ GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4]
 GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
 GPTQ_MARLIN_SUPPORTED_SYM = [True]
 
+
 # Precompute permutations for Marlin weight and scale shuffling
 #
 # Marlin works on [16,64] tiles. The goal of the permutations