@@ -19,6 +19,8 @@ _TP_PYNCCL_COMMUNICATOR = None
 _TP_CA_COMMUNICATOR = None
 # Pipeline model parallel group that the current rank belongs to.
 _PP_DEVICE_GROUP: Optional[ProcessGroup] = None
+_PP_CPU_GROUP: Optional[ProcessGroup] = None
+_PP_PYNCCL_COMMUNICATOR = None
 
 # when people blindly call `torch.distributed.all_reduce` etc,
 # it will use this group. It is initialized with the `backend`
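
Note: the new globals follow the existing TP pattern: each parallel dimension keeps a device group (typically NCCL) and a gloo-backed CPU group side by side, matching `_TP_CPU_GROUP`. A minimal sketch of that split, assuming `torch.distributed.init_process_group` has already run on every rank:

    import torch
    import torch.distributed as dist

    ranks = list(range(dist.get_world_size()))
    device_group = dist.new_group(ranks, backend="nccl")  # GPU tensors
    cpu_group = dist.new_group(ranks, backend="gloo")     # CPU tensors/objects

    dist.all_reduce(torch.ones(4, device="cuda"), group=device_group)
    dist.all_reduce(torch.ones(4), group=cpu_group)  # stays off the GPU stream
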
@@ -52,6 +54,11 @@ def set_custom_all_reduce(enable: bool):
     _ENABLE_CUSTOM_ALL_REDUCE = enable
 
 
+def get_pp_pynccl_communicator():
+    global _PP_PYNCCL_COMMUNICATOR
+    return _PP_PYNCCL_COMMUNICATOR
+
+
 def get_tp_pynccl_communicator():
     global _TP_PYNCCL_COMMUNICATOR
     return _TP_PYNCCL_COMMUNICATOR
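
Note: since the PP communicator is only constructed when the pipeline group has more than one rank (see the guard added further down), callers of the new accessor should tolerate `None`. A hedged sketch of a call site, assuming `all_reduce` matches the TP communicator's in-place interface:

    import torch
    import torch.distributed as dist
    from aphrodite.distributed.parallel_state import (
        get_pipeline_model_parallel_group, get_pp_pynccl_communicator)

    def pp_all_reduce(tensor: torch.Tensor) -> torch.Tensor:
        pynccl_comm = get_pp_pynccl_communicator()
        if pynccl_comm is not None:
            pynccl_comm.all_reduce(tensor)  # fast in-place NCCL path
        else:
            dist.all_reduce(tensor, group=get_pipeline_model_parallel_group())
        return tensor
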
@@ -173,10 +180,11 @@ def initialize_model_parallel(
 
     from aphrodite.distributed.device_communicators.pynccl import \
         PyNcclCommunicator
-    _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
-        group=_TP_CPU_GROUP,
-        device=_LOCAL_RANK,
-    )
+    if tensor_model_parallel_size > 1:
+        _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
+            group=_TP_CPU_GROUP,
+            device=_LOCAL_RANK,
+        )
 
     # Initialize a custom fast all-reduce implementation.
     if _ENABLE_CUSTOM_ALL_REDUCE:
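
Note: with this guard, a one-rank TP group no longer constructs a communicator, so `get_tp_pynccl_communicator()` may now return `None`. The communicator is handed the gloo CPU group rather than the device group because the NCCL unique ID must be exchanged out-of-band before any NCCL communicator exists. A conceptual sketch of that bootstrap, with hypothetical helpers (`create_nccl_unique_id`, `open_nccl_comm`), not the library's actual code:

    import torch.distributed as dist

    def bootstrap_nccl(cpu_group, device):
        rank = dist.get_rank(group=cpu_group)
        src = dist.get_global_rank(cpu_group, 0)
        # Group-rank 0 mints the ID; the others receive it over gloo.
        uid = [create_nccl_unique_id() if rank == 0 else None]  # hypothetical
        dist.broadcast_object_list(uid, src=src, group=cpu_group)
        return open_nccl_comm(uid[0], rank, device)  # hypothetical
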
@@ -188,17 +196,26 @@ def initialize_model_parallel(
         )
 
     # Build the pipeline model-parallel groups.
-    global _PP_DEVICE_GROUP
+    global _PP_DEVICE_GROUP, _PP_CPU_GROUP
+    global _PP_PYNCCL_COMMUNICATOR
     global _PP_GLOBAL_RANKS
     assert _PP_DEVICE_GROUP is None, (
         "pipeline model parallel group is already initialized")
     for i in range(num_pipeline_model_parallel_groups):
         ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
         group = torch.distributed.new_group(ranks, backend=backend)
+        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
         if rank in ranks:
             _PP_DEVICE_GROUP = group
+            _PP_CPU_GROUP = cpu_group
             _PP_GLOBAL_RANKS = ranks
 
+    if pipeline_model_parallel_size > 1:
+        _PP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
+            group=_PP_CPU_GROUP,
+            device=_LOCAL_RANK,
+        )
+
 
 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
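
Note: the stride in the loop interleaves pipeline groups across ranks. A quick offline preview of the layout (plain Python, no distributed init needed):

    world_size = 8
    pipeline_model_parallel_size = 2
    num_pipeline_model_parallel_groups = (
        world_size // pipeline_model_parallel_size)  # 4 groups

    groups = [
        list(range(i, world_size, num_pipeline_model_parallel_groups))
        for i in range(num_pipeline_model_parallel_groups)
    ]
    print(groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]

Each rank lands in exactly one group, and every group spans `pipeline_model_parallel_size` ranks.
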
@@ -260,6 +277,13 @@ def get_pipeline_model_parallel_group():
     return _PP_DEVICE_GROUP
 
 
+def get_pipeline_model_parallel_cpu_group():
+    """Get the pipeline model parallel cpu group the caller rank belongs to."""
+    assert _PP_CPU_GROUP is not None, (
+        "pipeline model parallel cpu group is not initialized")
+    return _PP_CPU_GROUP
+
+
 def get_tensor_model_parallel_world_size():
     """Return world size for the tensor model parallel group."""
     return torch.distributed.get_world_size(
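
Note: one plausible (hypothetical) use of the new CPU-group accessor is shipping small Python metadata between pipeline stages over gloo, keeping the NCCL device group free for tensor traffic:

    import torch.distributed as dist
    from aphrodite.distributed.parallel_state import (
        get_pipeline_model_parallel_cpu_group)

    cpu_group = get_pipeline_model_parallel_cpu_group()
    src = dist.get_global_rank(cpu_group, 0)  # first rank of this pipeline group

    # Illustrative payload only; any picklable object works.
    meta = [{"num_tokens": 128}] if dist.get_rank() == src else [None]
    dist.broadcast_object_list(meta, src=src, group=cpu_group)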