Browse Source

enable all-reduce for multiple tp groups

AlpinDale 8 months ago
parent
commit
1879e32510

+ 0 - 1
aphrodite/distributed/communication_op.py

@@ -33,7 +33,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     if out is not None:
         return out
     if is_pynccl_enabled_for_all_reduce():
-        # TODO: support multiple parallel groups.
         pynccl_utils.all_reduce(input_)
     else:
         torch.distributed.all_reduce(input_,

+ 26 - 12
aphrodite/distributed/parallel_state.py

@@ -12,7 +12,8 @@ import torch
 from loguru import logger
 
 # Tensor model parallel group that the current rank belongs to.
-_TENSOR_MODEL_PARALLEL_GROUP = None
+_TP_DEVICE_GROUP = None
+_TP_CPU_GROUP = None
 # Pipeline model parallel group that the current rank belongs to.
 _PIPELINE_MODEL_PARALLEL_GROUP = None
 
@@ -126,15 +127,17 @@ def initialize_model_parallel(
     rank = torch.distributed.get_rank()
 
     # Build the tensor model-parallel groups.
-    global _TENSOR_MODEL_PARALLEL_GROUP
-    assert _TENSOR_MODEL_PARALLEL_GROUP is None, (
+    global _TP_DEVICE_GROUP, _TP_CPU_GROUP
+    assert _TP_DEVICE_GROUP is None, (
         "tensor model parallel group is already initialized")
     for i in range(num_tensor_model_parallel_groups):
         ranks = range(i * tensor_model_parallel_size,
                       (i + 1) * tensor_model_parallel_size)
         group = torch.distributed.new_group(ranks, backend=backend)
+        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
         if rank in ranks:
-            _TENSOR_MODEL_PARALLEL_GROUP = group
+            _TP_DEVICE_GROUP = group
+            _TP_CPU_GROUP = cpu_group
 
     # Build the pipeline model-parallel groups.
     global _PIPELINE_MODEL_PARALLEL_GROUP
@@ -179,7 +182,7 @@ def ensure_model_parallel_initialized(
 
 def model_parallel_is_initialized():
     """Check if tensor and pipeline parallel groups are initialized."""
-    return (_TENSOR_MODEL_PARALLEL_GROUP is not None
+    return (_TP_DEVICE_GROUP is not None
             and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
 
 
@@ -191,9 +194,16 @@ def get_cpu_world_group():
 
 def get_tensor_model_parallel_group():
     """Get the tensor model parallel group the caller rank belongs to."""
-    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
-        "tenosr model parallel group is not initialized")
-    return _TENSOR_MODEL_PARALLEL_GROUP
+    assert _TP_DEVICE_GROUP is not None, (
+        "tensor model parallel group is not initialized")
+    return _TP_DEVICE_GROUP
+
+
+def get_tensor_model_parallel_cpu_group():
+    """Get the tensor model parallel cpu group the caller rank belongs to."""
+    assert _TP_CPU_GROUP is not None, (
+        "tensor model parallel cpu group is not initialized")
+    return _TP_CPU_GROUP
 
 
 def get_pipeline_model_parallel_group():
@@ -271,10 +281,14 @@ def get_pipeline_model_parallel_prev_rank():
 
 def destroy_model_parallel():
     """Set the groups to none and destroy them."""
-    global _TENSOR_MODEL_PARALLEL_GROUP
-    if _TENSOR_MODEL_PARALLEL_GROUP:
-        torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP)
-    _TENSOR_MODEL_PARALLEL_GROUP = None
+    global _TP_DEVICE_GROUP
+    if _TP_DEVICE_GROUP:
+        torch.distributed.destroy_process_group(_TP_DEVICE_GROUP)
+    _TP_DEVICE_GROUP = None
+    global _TP_CPU_GROUP
+    if _TP_CPU_GROUP:
+        torch.distributed.destroy_process_group(_TP_CPU_GROUP)
+    _TP_CPU_GROUP = None
     global _PIPELINE_MODEL_PARALLEL_GROUP
     if _PIPELINE_MODEL_PARALLEL_GROUP:
         torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)

+ 7 - 4
aphrodite/task_handler/worker.py

@@ -13,6 +13,7 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      SchedulerConfig, VisionLanguageConfig)
 from aphrodite.distributed import (broadcast_tensor_dict,
                                    ensure_model_parallel_initialized,
+                                   get_tensor_model_parallel_cpu_group,
                                    init_distributed_environment)
 from aphrodite.distributed.device_communicators import pynccl_utils
 from aphrodite.distributed.device_communicators.custom_all_reduce import \
@@ -289,6 +290,9 @@ def init_worker_distributed_environment(
     init_distributed_environment(parallel_config.world_size, rank,
                                  distributed_init_method, local_rank)
 
+    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
+                                      parallel_config.pipeline_parallel_size)
+
     if pynccl_utils.is_initialized():
         pynccl_world_size = pynccl_utils.get_world_size()
         if pynccl_world_size != parallel_config.world_size:
@@ -299,10 +303,9 @@ def init_worker_distributed_environment(
     elif parallel_config.world_size > 1:
         # NOTE: We don't initialize pynccl process group when world size
         # is 1.
-        pynccl_utils.init_process_group()
-
-    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
+        # NOTE: By default, pynccl is initialized for tp group.
+        pynccl_utils.init_process_group(
+            group=get_tensor_model_parallel_cpu_group())
 
     # Initialize a custom fast all-reduce implementation.
     if not parallel_config.disable_custom_all_reduce: