@@ -45,14 +45,16 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
 
 
 def _split_tensor_dict(
-    tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
-) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
+    tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
+    prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
     """Split the tensor dictionary into two parts:
     1. A list of (key, value) pairs. If the value is a tensor, it is replaced
          by its metadata.
     2. A list of tensors.
+    If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its
+    metadata will be "key1%key2".
     """
-    metadata_list = []
+    metadata_list: List[Tuple[str, Any]] = []
     tensor_list = []
     for key, value in tensor_dict.items():
         if isinstance(value, torch.Tensor):
@@ -62,13 +64,31 @@ def _split_tensor_dict(
             # receiving side will set the device index.
             device = value.device.type
             metadata_list.append(
-                (key, TensorMetadata(device, value.dtype, value.size())))
+                (prefix + key, TensorMetadata(device, value.dtype,
+                                              value.size())))
             tensor_list.append(value)
+        elif isinstance(value, dict):
+            if len(value) == 0:
+                metadata_list.append((prefix + key, value))
+            inner_metadata_list, inner_tensor_list = _split_tensor_dict(
+                value, prefix + key + "%")
+            metadata_list.extend(inner_metadata_list)
+            tensor_list.extend(inner_tensor_list)
         else:
-            metadata_list.append((key, value))
+            metadata_list.append((prefix + key, value))
     return metadata_list, tensor_list
 
 
+def _update_nested_dict(nested_dict, flattened_key, value):
+    key_splits = flattened_key.split("%")
+    cur_dict = nested_dict
+    for k in key_splits[:-1]:
+        if k not in cur_dict:
+            cur_dict[k] = {}
+        cur_dict = cur_dict[k]
+    cur_dict[key_splits[-1]] = value
+
+
 class GroupCoordinator:
     """
     PyTorch ProcessGroup wrapper for a group of processes.
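
To make the flattening scheme concrete, here is a small self-contained round-trip sketch (illustration only, not part of the patch; the example dict is made up, and the helper bodies mirror the ones added above):

```python
# Round-trip sketch: flatten a nested tensor dict, then rebuild it.
from collections import namedtuple
from typing import Any, Dict, List, Tuple, Union

import torch

TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])


def _split_tensor_dict(
    tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
    prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    metadata_list: List[Tuple[str, Any]] = []
    tensor_list = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            metadata_list.append(
                (prefix + key,
                 TensorMetadata(value.device.type, value.dtype, value.size())))
            tensor_list.append(value)
        elif isinstance(value, dict):
            # An empty dict carries no tensors, so record it directly in the
            # metadata; otherwise recurse with the "%"-joined prefix.
            if len(value) == 0:
                metadata_list.append((prefix + key, value))
            inner_metadata_list, inner_tensor_list = _split_tensor_dict(
                value, prefix + key + "%")
            metadata_list.extend(inner_metadata_list)
            tensor_list.extend(inner_tensor_list)
        else:
            metadata_list.append((prefix + key, value))
    return metadata_list, tensor_list


def _update_nested_dict(nested_dict, flattened_key, value):
    # Walk (and create) intermediate dicts for "key1%key2%..." and set the leaf.
    key_splits = flattened_key.split("%")
    cur_dict = nested_dict
    for k in key_splits[:-1]:
        if k not in cur_dict:
            cur_dict[k] = {}
        cur_dict = cur_dict[k]
    cur_dict[key_splits[-1]] = value


nested = {"a": torch.zeros(2), "b": {"c": torch.ones(3), "d": 42}, "e": {}}
metadata_list, tensor_list = _split_tensor_dict(nested)
print([k for k, _ in metadata_list])  # ['a', 'b%c', 'b%d', 'e']

# Receiving side: rebuild the nesting from the flattened keys.
rebuilt: Dict[str, Any] = {}
tensors = iter(tensor_list)
for key, value in metadata_list:
    if isinstance(value, TensorMetadata):
        _update_nested_dict(rebuilt, key, next(tensors))
    else:
        _update_nested_dict(rebuilt, key, value)
assert rebuilt["b"]["d"] == 42 and rebuilt["e"] == {}
```

Note the implicit contract this scheme relies on: keys must be strings that never contain "%", since that character doubles as the nesting separator, and an empty nested dict is recorded as a plain metadata entry so that it survives the round trip.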
@@ -512,7 +532,7 @@ class GroupCoordinator:
                                          device=value.device)
                     if tensor.numel() == 0:
                         # Skip broadcasting empty tensors.
-                        tensor_dict[key] = tensor
+                        _update_nested_dict(tensor_dict, key, tensor)
                         continue
                     if tensor.is_cpu:
                         # use metadata_group for CPU tensors
@@ -528,9 +548,9 @@ class GroupCoordinator:
                             group=group,
                             async_op=True)
                     async_handles.append(handle)
-                    tensor_dict[key] = tensor
+                    _update_nested_dict(tensor_dict, key, tensor)
                 else:
-                    tensor_dict[key] = value
+                    _update_nested_dict(tensor_dict, key, value)
             for async_handle in async_handles:
                 async_handle.wait()
         return tensor_dict
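
The hunks above rewire the receiving side of `broadcast_tensor_dict` to the same scheme. Below is a simplified, hypothetical stand-in for that flow (not the actual `GroupCoordinator` method): it reuses the helpers from the previous sketch and runs on a single-process gloo group, so the collectives are trivially satisfied and the script is executable without launching multiple ranks.

```python
# Simplified stand-in for the patched broadcast flow (illustration only).
# Assumes TensorMetadata, _split_tensor_dict, and _update_nested_dict from
# the sketch above are already defined.
import torch
import torch.distributed as dist

dist.init_process_group("gloo",
                        init_method="tcp://127.0.0.1:29500",
                        rank=0,
                        world_size=1)

src = 0
sent = {"model": {"hidden": torch.randn(2, 4)}, "step": 7}
metadata_list, tensor_list = _split_tensor_dict(sent)

# 1. Ship the metadata: flattened keys plus TensorMetadata placeholders.
container = [metadata_list]
dist.broadcast_object_list(container, src=src)

# 2. Ship each tensor, then re-nest on the receiving side. A real receiver
#    would first allocate with torch.empty(value.size, dtype=value.dtype, ...)
#    from the metadata; rank 0 already holds the data.
received: dict = {}
tensors = iter(tensor_list)
for key, value in container[0]:
    if isinstance(value, TensorMetadata):
        tensor = next(tensors)
        dist.broadcast(tensor, src=src)
        _update_nested_dict(received, key, tensor)
    else:
        _update_nested_dict(received, key, value)

assert received["model"]["hidden"].shape == (2, 4)
assert received["step"] == 7
dist.destroy_process_group()
```

With the old flat assignment `tensor_dict[key] = tensor`, the receiver would have materialized a literal "model%hidden" key; `_update_nested_dict` is what restores the nesting.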