@@ -1,3 +1,4 @@
+import ray
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -18,9 +19,15 @@ class TpuCommunicator:
             return
         self.disabled = False
 
-        local_rank = dist.get_rank(group)
-        world_size = dist.get_world_size(group)
-        pjrt.initialize_multiprocess(local_rank, world_size)
+        # NOTE: When using TP > 1 on TPUs, every TPU on the same node
+        # must be used together. Therefore, the local rank and world
+        # size can be simply calculated as follows.
+        global_rank = dist.get_rank(group)
+        global_world_size = dist.get_world_size(group)
+        num_nodes = len(ray.nodes())
+        local_world_size = global_world_size // num_nodes
+        local_rank = global_rank % local_world_size
+        pjrt.initialize_multiprocess(local_rank, local_world_size)
         xr._init_world_size_ordinal()
 
     def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
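For reference, a minimal sketch of the rank arithmetic introduced above, assuming a hypothetical deployment of 2 Ray nodes with 4 TPU chips each (the node and chip counts are illustrative, not taken from this change, and it assumes TP workers are placed on nodes in contiguous global-rank order, which is the placement the modulo computation relies on):

```python
# Illustrative values only: 2 nodes x 4 TPU chips per node (assumed setup).
global_world_size = 8                              # dist.get_world_size(group)
num_nodes = 2                                      # len(ray.nodes())
local_world_size = global_world_size // num_nodes  # 4 chips per node

for global_rank in range(global_world_size):
    local_rank = global_rank % local_world_size
    print(f"global_rank={global_rank} -> local_rank={local_rank}")

# global_rank 0-3 map to local_rank 0-3 (first node),
# global_rank 4-7 map to local_rank 0-3 (second node),
# matching the (local_rank, local_world_size) arguments passed to
# pjrt.initialize_multiprocess() in the diff.
```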