Browse Source

fix: make tensor_dict flattening/unflattening more generic

AlpinDale 7 months ago
parent
commit
bc5ac9584a
1 changed files with 28 additions and 8 deletions
  1. 28 8
      aphrodite/distributed/parallel_state.py

+ 28 - 8
aphrodite/distributed/parallel_state.py

@@ -45,14 +45,16 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
 
 
 def _split_tensor_dict(
-    tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
-) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
+        tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
+        prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
     """Split the tensor dictionary into two parts:
     1. A list of (key, value) pairs. If the value is a tensor, it is replaced
          by its metadata.
     2. A list of tensors.
+    If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its
+    metadata will be "key1%key2".
     """
-    metadata_list = []
+    metadata_list: List[Tuple[str, Any]] = []
     tensor_list = []
     for key, value in tensor_dict.items():
         if isinstance(value, torch.Tensor):
@@ -62,13 +64,31 @@ def _split_tensor_dict(
             # receiving side will set the device index.
             device = value.device.type
             metadata_list.append(
-                (key, TensorMetadata(device, value.dtype, value.size())))
+                (prefix + key, TensorMetadata(device, value.dtype,
+                                              value.size())))
             tensor_list.append(value)
+        elif isinstance(value, dict):
+            if len(value) == 0:
+                metadata_list.append((prefix + key, value))
+            inner_metadata_list, inner_tensor_list = _split_tensor_dict(
+                value, prefix + key + "%")
+            metadata_list.extend(inner_metadata_list)
+            tensor_list.extend(inner_tensor_list)
         else:
-            metadata_list.append((key, value))
+            metadata_list.append((prefix + key, value))
     return metadata_list, tensor_list
 
 
+def _update_nested_dict(nested_dict, flattened_key, value):
+    key_splits = flattened_key.split("%")
+    cur_dict = nested_dict
+    for k in key_splits[:-1]:
+        if k not in cur_dict:
+            cur_dict[k] = {}
+        cur_dict = cur_dict[k]
+    cur_dict[key_splits[-1]] = value
+
+
 class GroupCoordinator:
     """
     PyTorch ProcessGroup wrapper for a group of processes.
@@ -512,7 +532,7 @@ class GroupCoordinator:
                                          device=value.device)
                     if tensor.numel() == 0:
                         # Skip broadcasting empty tensors.
-                        tensor_dict[key] = tensor
+                        _update_nested_dict(tensor_dict, key, tensor)
                         continue
                     if tensor.is_cpu:
                         # use metadata_group for CPU tensors
@@ -528,9 +548,9 @@ class GroupCoordinator:
                                                              group=group,
                                                              async_op=True)
                     async_handles.append(handle)
-                    tensor_dict[key] = tensor
+                    _update_nested_dict(tensor_dict, key, tensor)
                 else:
-                    tensor_dict[key] = value
+                    _update_nested_dict(tensor_dict, key, value)
             for async_handle in async_handles:
                 async_handle.wait()
         return tensor_dict