@@ -226,6 +226,7 @@ class NCCLCommunicator:
         assert dist.get_backend(group) != dist.Backend.NCCL, (
             "NCCLCommunicator should be attached to a non-NCCL group.")
         self.group = group
+        # NOTE: this rank is the rank in the group
         self.rank = dist.get_rank(group)
         self.world_size = dist.get_world_size(group)
         if self.rank == 0:
@@ -233,7 +234,9 @@ class NCCLCommunicator:
         else:
             self.unique_id = NcclUniqueId()
         tensor = torch.ByteTensor(list(self.unique_id.internal))
-        dist.broadcast(tensor, src=0, group=group)
+        ranks = dist.get_process_group_ranks(group)
+        # arg `src` in `broadcast` is the global rank
+        dist.broadcast(tensor, src=ranks[0], group=group)
         byte_list = tensor.tolist()
         for i, byte in enumerate(byte_list):
             self.unique_id.internal[i] = byte
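For context, here is a minimal standalone sketch of the group-rank versus global-rank distinction the patch relies on. It is not part of the change: the 4-process launch, the `demo` helper, and the choice of subgroup ranks are illustrative assumptions, and `dist.get_process_group_ranks` requires a recent PyTorch release.

```python
import torch
import torch.distributed as dist


def demo() -> None:
    # Gloo backend so the sketch runs on CPU-only machines.
    dist.init_process_group(backend="gloo")
    global_rank = dist.get_rank()

    # Hypothetical subgroup made of global ranks 2 and 3 only.
    # new_group() is collective, so every process must call it.
    group = dist.new_group(ranks=[2, 3])

    if global_rank in (2, 3):
        # Rank *inside the group*: 0 for global rank 2, 1 for global rank 3.
        group_rank = dist.get_rank(group)

        payload = torch.zeros(1)
        if group_rank == 0:
            payload.fill_(42.0)

        # broadcast() expects `src` as a GLOBAL rank, so the group leader
        # is ranks[0] == 2 here; src=0 would not point at a group member.
        ranks = dist.get_process_group_ranks(group)
        dist.broadcast(payload, src=ranks[0], group=group)
        print(f"global rank {global_rank}: {payload.item()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    demo()
```

Launched with e.g. `torchrun --nproc_per_node=4 demo.py`, both subgroup members should print 42.0; passing `src=0` instead of `src=ranks[0]` would name a process that is not in the group, which is exactly the bug the `ranks[0]` change in the diff avoids.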