@@ -12,7 +12,8 @@ import torch

 from loguru import logger

 # Tensor model parallel group that the current rank belongs to.
-_TENSOR_MODEL_PARALLEL_GROUP = None
+_TP_DEVICE_GROUP = None
+_TP_CPU_GROUP = None
 # Pipeline model parallel group that the current rank belongs to.
 _PIPELINE_MODEL_PARALLEL_GROUP = None
@@ -126,15 +127,17 @@ def initialize_model_parallel(
     rank = torch.distributed.get_rank()

     # Build the tensor model-parallel groups.
-    global _TENSOR_MODEL_PARALLEL_GROUP
-    assert _TENSOR_MODEL_PARALLEL_GROUP is None, (
+    global _TP_DEVICE_GROUP, _TP_CPU_GROUP
+    assert _TP_DEVICE_GROUP is None, (
         "tensor model parallel group is already initialized")
     for i in range(num_tensor_model_parallel_groups):
         ranks = range(i * tensor_model_parallel_size,
                       (i + 1) * tensor_model_parallel_size)
         group = torch.distributed.new_group(ranks, backend=backend)
+        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
         if rank in ranks:
-            _TENSOR_MODEL_PARALLEL_GROUP = group
+            _TP_DEVICE_GROUP = group
+            _TP_CPU_GROUP = cpu_group

     # Build the pipeline model-parallel groups.
     global _PIPELINE_MODEL_PARALLEL_GROUP
@@ -179,7 +182,7 @@ def ensure_model_parallel_initialized(

 def model_parallel_is_initialized():
     """Check if tensor and pipeline parallel groups are initialized."""
-    return (_TENSOR_MODEL_PARALLEL_GROUP is not None
+    return (_TP_DEVICE_GROUP is not None
             and _PIPELINE_MODEL_PARALLEL_GROUP is not None)


@@ -191,9 +194,16 @@ def get_cpu_world_group():

 def get_tensor_model_parallel_group():
     """Get the tensor model parallel group the caller rank belongs to."""
-    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
-        "tenosr model parallel group is not initialized")
-    return _TENSOR_MODEL_PARALLEL_GROUP
+    assert _TP_DEVICE_GROUP is not None, (
+        "tensor model parallel group is not initialized")
+    return _TP_DEVICE_GROUP
+
+
+def get_tensor_model_parallel_cpu_group():
+    """Get the tensor model parallel cpu group the caller rank belongs to."""
+    assert _TP_CPU_GROUP is not None, (
+        "tensor model parallel cpu group is not initialized")
+    return _TP_CPU_GROUP


 def get_pipeline_model_parallel_group():
@@ -271,10 +281,14 @@ def get_pipeline_model_parallel_prev_rank():

 def destroy_model_parallel():
     """Set the groups to none and destroy them."""
-    global _TENSOR_MODEL_PARALLEL_GROUP
-    if _TENSOR_MODEL_PARALLEL_GROUP:
-        torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP)
-    _TENSOR_MODEL_PARALLEL_GROUP = None
+    global _TP_DEVICE_GROUP
+    if _TP_DEVICE_GROUP:
+        torch.distributed.destroy_process_group(_TP_DEVICE_GROUP)
+    _TP_DEVICE_GROUP = None
+    global _TP_CPU_GROUP
+    if _TP_CPU_GROUP:
+        torch.distributed.destroy_process_group(_TP_CPU_GROUP)
+    _TP_CPU_GROUP = None
     global _PIPELINE_MODEL_PARALLEL_GROUP
     if _PIPELINE_MODEL_PARALLEL_GROUP:
         torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)