Browse Source

torch.compile: register all-reduce operations as custom ops (#1050)

AlpinDale 2 months ago
parent
commit
239a8cae25

+ 0 - 6
aphrodite/_custom_ops.py

@@ -901,12 +901,6 @@ def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
                                                  offsets, rank, full_nvlink)
 
 
-def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
-                     full_nvlink: bool) -> bool:
-    return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size,
-                                                   full_nvlink)
-
-
 def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
     torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out)
 

+ 19 - 2
aphrodite/distributed/device_communicators/custom_all_reduce.py

@@ -32,6 +32,12 @@ def _can_p2p(rank: int, world_size: int) -> bool:
     return True
 
 
+def is_weak_contiguous(inp: torch.Tensor):
+    return inp.is_contiguous() or (inp.storage().nbytes() -
+                                   inp.storage_offset() * inp.element_size()
+                                   == inp.numel() * inp.element_size())
+
+
 class CustomAllreduce:
 
     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
@@ -230,8 +236,19 @@ class CustomAllreduce:
         ops.register_graph_buffers(self._ptr, handles, offsets)
 
     def should_custom_ar(self, inp: torch.Tensor):
-        return ops.should_custom_ar(inp, self.max_size, self.world_size,
-                                    self.full_nvlink)
+        if self.disabled:
+            return False
+        inp_size = inp.numel() * inp.element_size()
+        # custom allreduce requires input byte size to be multiples of 16
+        if inp_size % 16 != 0:
+            return False
+        if not is_weak_contiguous(inp):
+            return False
+        # for 4 or more non NVLink-capable GPUs, custom allreduce provides
+        # little performance improvement over NCCL.
+        if self.world_size == 2 or self.full_nvlink:
+            return inp_size < self.max_size
+        return False
 
     # all reduce, assuming inp tensor is IPC registered with register_buffer,
     # or, in the context of cuda graphs, register_graph_buffers

+ 87 - 14
aphrodite/distributed/parallel_state.py

@@ -23,11 +23,12 @@ If you only need to use the distributed environment without model/pipeline
 import contextlib
 import pickle
 import sys
+import weakref
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from multiprocessing import shared_memory
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from unittest.mock import patch
 
 import torch
@@ -71,6 +72,48 @@ def _split_tensor_dict(
     return metadata_list, tensor_list
 
 
+_group_name_counter: Dict[str, int] = {}
+def _get_unique_name(name: str) -> str:
+    """Get a unique name for the group.
+    Example:
+    _get_unique_name("tp") -> "tp:0"
+    _get_unique_name("tp") -> "tp:1"
+    """
+    if name not in _group_name_counter:
+        _group_name_counter[name] = 0
+    newname = f"{name}:{_group_name_counter[name]}"
+    _group_name_counter[name] += 1
+    return newname
+_groups: Dict[str, Callable[[], "GroupCoordinator"]] = {}
+
+def _register_group(group: "GroupCoordinator") -> None:
+    # looks like Python 3.8 does not understand `ReferenceType`
+    _groups[group.unique_name] = weakref.ref(group)  # type: ignore
+
+@torch.library.custom_op("vllm::inplace_all_reduce", mutates_args=["tensor"])
+def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
+    assert group_name in _groups, f"Group {group_name} is not found."
+    group = _groups[group_name]()
+    if group is None:
+        raise ValueError(f"Group {group_name} is destroyed.")
+    group._all_reduce(tensor)
+
+@inplace_all_reduce.register_fake
+def _(tensor: torch.Tensor, group_name: str) -> None:
+    return
+@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
+def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    assert group_name in _groups, f"Group {group_name} is not found."
+    group = _groups[group_name]()
+    if group is None:
+        raise ValueError(f"Group {group_name} is destroyed.")
+    return group._all_reduce(tensor)
+
+@outplace_all_reduce.register_fake
+def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    return torch.empty_like(tensor)
+
+
 class GroupCoordinator:
     """
     PyTorch ProcessGroup wrapper for a group of processes.
@@ -113,7 +156,11 @@ class GroupCoordinator:
         use_custom_allreduce: bool,
         use_tpu_communicator: bool,
         use_message_queue_broadcaster: bool = False,
+        group_name: Optional[str] = None,
     ):
+        group_name = group_name or "anonymous"
+        self.unique_name = _get_unique_name(group_name)
+        _register_group(self)
 
         self.rank = torch.distributed.get_rank()
         self.local_rank = local_rank
@@ -151,28 +198,24 @@ class GroupCoordinator:
         from aphrodite.distributed.device_communicators.pynccl import (
             PyNcclCommunicator)
 
-        self.pynccl_comm: Optional[PyNcclCommunicator]
+        self.pynccl_comm: Optional[PyNcclCommunicator] = None
         if use_pynccl and self.world_size > 1:
             self.pynccl_comm = PyNcclCommunicator(
                 group=self.cpu_group,
                 device=self.device,
             )
-        else:
-            self.pynccl_comm = None
 
-        self.ca_comm: Optional[CustomAllreduce]
+        self.ca_comm: Optional[CustomAllreduce] = None
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
             self.ca_comm = CustomAllreduce(
                 group=self.cpu_group,
                 device=self.device,
             )
-        else:
-            self.ca_comm = None
 
         from aphrodite.distributed.device_communicators.tpu_communicator import (  # noqa: E501
             TpuCommunicator)
-        self.tpu_communicator: Optional[TpuCommunicator]
+        self.tpu_communicator: Optional[TpuCommunicator] = None
         if use_tpu_communicator and self.world_size > 1:
             self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
 
@@ -266,16 +309,41 @@ class GroupCoordinator:
 
     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         """
+        User-facing all-reduce function before we actually call the
+        all-reduce operation.
+        We need this because Dynamo does not support passing an arbitrary
+        object (`self` in this case) to a custom op. We need to pass the
+         group name as a string, and then look up the group coordinator from
+         the group name, dispatch the all-reduce operation to the group
+         coordinator.
+        In addition, PyTorch custom ops do not support mutation or returning
+        a new tensor in the same op. So we need to figure out if the op is
+        in-place or out-of-place ahead of time.
+        """
+        # Bypass the function if we are using only 1 GPU.
+        if self.world_size == 1:
+            return input_
+        if self.tpu_communicator is not None and \
+            not self.tpu_communicator.disabled:
+            # TPU handles Dynamo with its own logic.
+            return self._all_reduce(input_)
+        if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
+            return torch.ops.aphrodite.outplace_all_reduce(
+                input_, group_name=self.unique_name)
+        else:
+            torch.ops.aphrodite.inplace_all_reduce(input_,
+                                              group_name=self.unique_name)
+            return input_
+
+    def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
+        """
+        The actual all-reduce implementation.
         NOTE: This operation will be applied in-place or out-of-place. 
         Always assume this function modifies its input, but use the return
         value as the output.
         """
         ca_comm = self.ca_comm
 
-        # Bypass the function if we are using only 1 GPU.
-        if self.world_size == 1:
-            return input_
-
         # For TPUs, use TPU communicator.
         tpu_comm = self.tpu_communicator
         if tpu_comm is not None and not tpu_comm.disabled:
@@ -760,6 +828,7 @@ def init_world_group(ranks: List[int], local_rank: int,
         use_pynccl=False,
         use_custom_allreduce=False,
         use_tpu_communicator=False,
+        group_name="world",
     )
 
 
@@ -769,6 +838,7 @@ def init_model_parallel_group(
     backend: str,
     use_custom_allreduce: Optional[bool] = None,
     use_message_queue_broadcaster: bool = False,
+    group_name: Optional[str] = None,
 ) -> GroupCoordinator:
     if use_custom_allreduce is None:
         use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
@@ -780,6 +850,7 @@ def init_model_parallel_group(
         use_custom_allreduce=use_custom_allreduce,
         use_tpu_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
+        group_name=group_name,
     )
 
 
@@ -935,7 +1006,8 @@ def initialize_model_parallel(
     _TP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
-                                    use_message_queue_broadcaster=True)
+                                    use_message_queue_broadcaster=True,
+                                    group_name="tp")
 
     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = (world_size //
@@ -951,7 +1023,8 @@ def initialize_model_parallel(
     _PP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
-                                    use_custom_allreduce=False)
+                                    use_custom_allreduce=False,
+                                    group_name="pp")
 
 
 def ensure_model_parallel_initialized(

+ 0 - 12
kernels/all_reduce/custom_all_reduce.cu

@@ -56,18 +56,6 @@ bool _is_weak_contiguous(torch::Tensor& t) {
           t.numel() * t.element_size());
 }
 
-bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
-                      bool full_nvlink) {
-  auto inp_size = inp.numel() * inp.element_size();
-  // custom allreduce requires input byte size to be multiples of 16
-  if (inp_size % 16 != 0) return false;
-  if (!_is_weak_contiguous(inp)) return false;
-  if (world_size == 2 || full_nvlink) return inp_size <= max_size;
-  // for 4 or more non NVLink-capable GPUs, custom allreduce provides little
-  // performance improvement over NCCL.
-  return false;
-}
-
 void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                  cudaStream_t stream) {
   auto fa = reinterpret_cast<aphrodite::CustomAllreduce*>(_fa);

+ 0 - 2
kernels/ops.h

@@ -79,8 +79,6 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                       const std::vector<std::string>& handles,
                       const std::vector<int64_t>& offsets, int64_t rank,
                       bool full_nvlink);
-bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
-                      bool full_nvlink);
 void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
 void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                       torch::Tensor& out);

+ 0 - 4
kernels/torch_bindings.cpp

@@ -512,10 +512,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
       "str[] handles, int[] offsets, int rank, "
       "bool full_nvlink) -> int");
   custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
-  custom_ar.def(
-      "should_custom_ar(Tensor inp, int max_size, int world_size, "
-      "bool full_nvlink) -> bool");
-  custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);
 
   custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
   custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);

+ 0 - 0
tests/compile/__init__.py


+ 11 - 2
tests/compile/test_full_graph.py

@@ -2,9 +2,18 @@ import os
 
 import pytest
 
+from aphrodite.common.utils import cuda_device_count_stateless
+
+from ..utils import fork_new_process_for_each_test
+
 
 @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
-def test_full_graph(model):
+@pytest.mark.parametrize("tp_size", [1, 2])
+@fork_new_process_for_each_test
+def test_full_graph(model, tp_size):
+    # Skip the test if there are not enough CUDA devices.
+    if cuda_device_count_stateless() < tp_size:
+        pytest.skip("Not enough CUDA devices for the test.")
     # make sure these models can be captured in full graph mode
     os.environ["APHRODITE_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
 
@@ -16,7 +25,7 @@ def test_full_graph(model):
         "The future of AI is",
     ]
     sampling_params = SamplingParams(temperature=0)
-    llm = LLM(model=model, enforce_eager=True)
+    llm = LLM(model=model, enforce_eager=True, tensor_parallel_size=tp_size)
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
     for output in outputs: