@@ -23,11 +23,12 @@ If you only need to use the distributed environment without model/pipeline
 import contextlib
 import pickle
 import sys
+import weakref
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from multiprocessing import shared_memory
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from unittest.mock import patch

 import torch
@@ -71,6 +72,48 @@ def _split_tensor_dict(
     return metadata_list, tensor_list


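+# Counter used by _get_unique_name to give groups that share a base name a
+# distinct suffix (e.g. "tp" -> "tp:0", "tp:1").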
+_group_name_counter: Dict[str, int] = {}
+def _get_unique_name(name: str) -> str:
+    """Get a unique name for the group.
+    Example:
+        _get_unique_name("tp") -> "tp:0"
+        _get_unique_name("tp") -> "tp:1"
+    """
+    if name not in _group_name_counter:
+        _group_name_counter[name] = 0
+    newname = f"{name}:{_group_name_counter[name]}"
+    _group_name_counter[name] += 1
+    return newname
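+# Registry of live coordinators, keyed by unique name. Weak references are
+# used so this registry does not keep a destroyed group alive.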
+_groups: Dict[str, Callable[[], "GroupCoordinator"]] = {}
+
+def _register_group(group: "GroupCoordinator") -> None:
+    # looks like Python 3.8 does not understand `ReferenceType`
+    _groups[group.unique_name] = weakref.ref(group)  # type: ignore
+
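+# The collectives are exposed as PyTorch custom ops so that Dynamo can trace
+# them: the ops take only a tensor and a group name (a plain string) and look
+# the coordinator up in _groups, instead of receiving `self` directly.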
+@torch.library.custom_op("aphrodite::inplace_all_reduce", mutates_args=["tensor"])
+def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
+    assert group_name in _groups, f"Group {group_name} is not found."
+    group = _groups[group_name]()
+    if group is None:
+        raise ValueError(f"Group {group_name} is destroyed.")
+    group._all_reduce(tensor)
+
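+# Fake (meta) implementation used during tracing/compilation; the in-place op
+# mutates `tensor`, so there is nothing to return.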
+@inplace_all_reduce.register_fake
+def _(tensor: torch.Tensor, group_name: str) -> None:
+    return
+@torch.library.custom_op("aphrodite::outplace_all_reduce", mutates_args=[])
+def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    assert group_name in _groups, f"Group {group_name} is not found."
+    group = _groups[group_name]()
+    if group is None:
+        raise ValueError(f"Group {group_name} is destroyed.")
+    return group._all_reduce(tensor)
+
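+# Fake (meta) implementation: the out-of-place op returns a tensor with the
+# same shape and dtype as the input.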
+@outplace_all_reduce.register_fake
+def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    return torch.empty_like(tensor)
+
+
 class GroupCoordinator:
     """
     PyTorch ProcessGroup wrapper for a group of processes.
@@ -113,7 +156,11 @@ class GroupCoordinator:
         use_custom_allreduce: bool,
         use_tpu_communicator: bool,
         use_message_queue_broadcaster: bool = False,
+        group_name: Optional[str] = None,
     ):
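+        # Give this coordinator a unique name (e.g. "tp:0") and register it,
+        # so the custom all-reduce ops can look it up by name.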
+        group_name = group_name or "anonymous"
+        self.unique_name = _get_unique_name(group_name)
+        _register_group(self)

         self.rank = torch.distributed.get_rank()
         self.local_rank = local_rank
@@ -151,28 +198,24 @@ class GroupCoordinator:
         from aphrodite.distributed.device_communicators.pynccl import (
             PyNcclCommunicator)

-        self.pynccl_comm: Optional[PyNcclCommunicator]
+        self.pynccl_comm: Optional[PyNcclCommunicator] = None
         if use_pynccl and self.world_size > 1:
             self.pynccl_comm = PyNcclCommunicator(
                 group=self.cpu_group,
                 device=self.device,
             )
-        else:
-            self.pynccl_comm = None

-        self.ca_comm: Optional[CustomAllreduce]
+        self.ca_comm: Optional[CustomAllreduce] = None
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
             self.ca_comm = CustomAllreduce(
                 group=self.cpu_group,
                 device=self.device,
             )
-        else:
-            self.ca_comm = None

         from aphrodite.distributed.device_communicators.tpu_communicator import (  # noqa: E501
             TpuCommunicator)
-        self.tpu_communicator: Optional[TpuCommunicator]
+        self.tpu_communicator: Optional[TpuCommunicator] = None
         if use_tpu_communicator and self.world_size > 1:
             self.tpu_communicator = TpuCommunicator(group=self.cpu_group)

@@ -266,16 +309,41 @@ class GroupCoordinator:

     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         """
+        User-facing all-reduce function before we actually call the
+        all-reduce operation.
+        We need this because Dynamo does not support passing an arbitrary
+        object (`self` in this case) to a custom op. We need to pass the
+        group name as a string, look up the group coordinator from the
+        group name, and then dispatch the all-reduce operation to that
+        coordinator.
+        In addition, a PyTorch custom op cannot both mutate its input and
+        return a new tensor in the same op, so we need to figure out
+        whether the op is in-place or out-of-place ahead of time.
+        """
+        # Bypass the function if we are using only 1 GPU.
+        if self.world_size == 1:
+            return input_
+        if self.tpu_communicator is not None and \
+                not self.tpu_communicator.disabled:
+            # TPU handles Dynamo with its own logic.
+            return self._all_reduce(input_)
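+        # The custom all-reduce kernel returns a new tensor, so it goes
+        # through the out-of-place op; every other backend reduces in place.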
+        if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
+            return torch.ops.aphrodite.outplace_all_reduce(
+                input_, group_name=self.unique_name)
+        else:
+            torch.ops.aphrodite.inplace_all_reduce(
+                input_, group_name=self.unique_name)
+            return input_
+
+    def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
+        """
+        The actual all-reduce implementation.
         NOTE: This operation will be applied in-place or out-of-place.
         Always assume this function modifies its input, but use the return
         value as the output.
         """
         ca_comm = self.ca_comm

-        # Bypass the function if we are using only 1 GPU.
-        if self.world_size == 1:
-            return input_
-
         # For TPUs, use TPU communicator.
         tpu_comm = self.tpu_communicator
         if tpu_comm is not None and not tpu_comm.disabled:
@@ -760,6 +828,7 @@ def init_world_group(ranks: List[int], local_rank: int,
         use_pynccl=False,
         use_custom_allreduce=False,
         use_tpu_communicator=False,
+        group_name="world",
     )


@@ -769,6 +838,7 @@ def init_model_parallel_group(
     backend: str,
     use_custom_allreduce: Optional[bool] = None,
     use_message_queue_broadcaster: bool = False,
+    group_name: Optional[str] = None,
 ) -> GroupCoordinator:
     if use_custom_allreduce is None:
         use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
@@ -780,6 +850,7 @@ def init_model_parallel_group(
         use_custom_allreduce=use_custom_allreduce,
         use_tpu_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
+        group_name=group_name,
     )


@@ -935,7 +1006,8 @@ def initialize_model_parallel(
     _TP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
-                                    use_message_queue_broadcaster=True)
+                                    use_message_queue_broadcaster=True,
+                                    group_name="tp")

     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = (world_size //
@@ -951,7 +1023,8 @@ def initialize_model_parallel(
     _PP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
-                                    use_custom_allreduce=False)
+                                    use_custom_allreduce=False,
+                                    group_name="pp")


 def ensure_model_parallel_initialized(