
feat: add support for multi-host tpu (#707)

AlpinDale, 6 months ago
commit 36241f98af

+10 -3  aphrodite/distributed/device_communicators/tpu_communicator.py

@@ -1,3 +1,4 @@
+import ray
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -18,9 +19,15 @@ class TpuCommunicator:
             return
         self.disabled = False
 
-        local_rank = dist.get_rank(group)
-        world_size = dist.get_world_size(group)
-        pjrt.initialize_multiprocess(local_rank, world_size)
+        # NOTE: When using TP > 1 on TPUs, every TPU on the same node
+        # must be used together. Therefore, the local rank and world
+        # size can be simply calculated as follows.
+        global_rank = dist.get_rank(group)
+        global_world_size = dist.get_world_size(group)
+        num_nodes = len(ray.nodes())
+        local_world_size = global_world_size // num_nodes
+        local_rank = global_rank % local_world_size
+        pjrt.initialize_multiprocess(local_rank, local_world_size)
         xr._init_world_size_ordinal()
 
     def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
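
To illustrate the rank arithmetic introduced in this diff, here is a minimal standalone sketch with assumed example values (2 Ray nodes, 4 TPU chips per host); it is not part of the commit, it only reproduces the local_world_size / local_rank calculation in isolation.

# Standalone sketch of the local-rank derivation above.
# num_nodes and global_world_size are assumed example values:
# in the commit they come from len(ray.nodes()) and
# dist.get_world_size(group) respectively.
num_nodes = 2          # e.g. a 2-host TPU slice
global_world_size = 8  # 4 chips per host

local_world_size = global_world_size // num_nodes  # -> 4
for global_rank in range(global_world_size):
    local_rank = global_rank % local_world_size
    print(f"global rank {global_rank} -> local rank {local_rank} of {local_world_size}")

# Output: global ranks 0-3 fold onto local ranks 0-3 on node 0,
# and global ranks 4-7 fold onto local ranks 0-3 on node 1,
# which is what pjrt.initialize_multiprocess expects per host.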