Browse Source

tpu: support single and multi-host TPUs on GKE and RayServe (#970)

AlpinDale 2 months ago
parent
commit
61103b92d4

+ 3 - 1
aphrodite/attention/backends/pallas.py

@@ -125,7 +125,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
             raise NotImplementedError("TPU version must be 4 or higher.")
 
         self.megacore_mode = None
-        tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower()
+        tpu_env = torch_xla.tpu.get_tpu_env()
+        tpu_type = tpu_env.get("TYPE") or tpu_env.get("ACCELERATOR_TYPE")
+        tpu_type = tpu_type.lower()
         if "lite" not in tpu_type:
             if self.num_kv_heads % 2 == 0:
                 self.megacore_mode = "kv_head"

+ 21 - 2
aphrodite/distributed/device_communicators/tpu_communicator.py

@@ -1,3 +1,5 @@
+import os
+
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -5,11 +7,12 @@ from torch.distributed import ProcessGroup
 from aphrodite.platforms import current_platform
 
 if current_platform.is_tpu():
-    import ray
     import torch_xla.core.xla_model as xm
     import torch_xla.runtime as xr
     from torch_xla._internal import pjrt
 
+    from aphrodite.executor import ray_utils
+
 
 class TpuCommunicator:
 
@@ -24,9 +27,25 @@ class TpuCommunicator:
         # size can be simply calculated as follows.
         global_rank = dist.get_rank(group)
         global_world_size = dist.get_world_size(group)
-        num_nodes = len(ray.nodes())
+        # Calculate how many TPU nodes are in the current deployment. This
+        # is the Ray placement group if it is deployed with Ray. Default
+        # to the number of TPU nodes in the Ray cluster. The number of TPU
+        # nodes is computed by the total number of TPUs divided by the
+        # number of TPU accelerators per node, to account for clusters
+        # with both CPUs and TPUs.
+        num_nodes = ray_utils.get_num_tpu_nodes()
+        num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
+        if num_nodes_in_pg > 0:
+            num_nodes = num_nodes_in_pg
         local_world_size = global_world_size // num_nodes
         local_rank = global_rank % local_world_size
+        # Ensure environment variables are set for multihost deployments.
+        # On GKE, this is needed for libtpu and TPU driver to know which TPU
+        # chip is actually visible. Otherwise the TPU driver will fail to
+        # initialize because the number of devices would be different from
+        # the number of visible worker addresses.
+        os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank)
+        os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)
         pjrt.initialize_multiprocess(local_rank, local_world_size)
         xr._init_world_size_ordinal()
 

+ 15 - 0
aphrodite/executor/ray_tpu_executor.py

@@ -73,6 +73,19 @@ class RayTPUExecutor(TPUExecutor):
             worker_module_name = "aphrodite.task_handler.tpu_worker"
             worker_class_name = "TPUWorker"
 
+            # GKE does not fetch environment information from metadata server
+            # and instead sets these from within the Ray process. Therefore we
+            # need to override the Ray environment variables manually.
+            override_env = {}
+            if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ:
+                override_env.update({
+                    "TPU_CHIPS_PER_HOST_BOUNDS":
+                    os.environ["TPU_CHIPS_PER_HOST_BOUNDS"]
+                })
+            if "TPU_HOST_BOUNDS" in os.environ:
+                override_env.update(
+                    {"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]})
+
             worker = ray.remote(
                 num_cpus=0,
                 resources={"TPU": 1},
@@ -83,6 +96,8 @@ class RayTPUExecutor(TPUExecutor):
                 worker_class_name=worker_class_name,
                 trust_remote_code=self.model_config.trust_remote_code,
             )
+            if override_env:
+                worker.override_env_vars.remote(override_env)
 
             worker_ip = ray.get(worker.get_node_ip.remote())
             if worker_ip == driver_ip and self.driver_dummy_worker is None:

+ 27 - 0
aphrodite/executor/ray_utils.py

@@ -1,3 +1,4 @@
+import os
 import time
 from collections import defaultdict
 from typing import Dict, List, Optional, Tuple, Union
@@ -80,6 +81,9 @@ try:
                 output = self.output_encoder.encode(output)
             return output
 
+        def override_env_vars(self, vars: Dict[str, str]):
+            os.environ.update(vars)
+
     ray_import_err = None
 
 except ImportError as e:
@@ -139,6 +143,7 @@ def _verify_bundles(placement_group: "PlacementGroup",
                 "sure you have more than "
                 f"than {parallel_config.tensor_parallel_size} GPUs available "
                 "at each node.")
+
 def _wait_until_pg_ready(current_placement_group: "PlacementGroup"):
     """Wait until a placement group is ready.
     It prints the informative log messages if the placement group is
@@ -271,3 +276,25 @@ def initialize_ray_cluster(
     _verify_bundles(current_placement_group, parallel_config, device_str)
     # Set the placement group in the parallel config
     parallel_config.placement_group = current_placement_group
+
+
+def get_num_tpu_nodes() -> int:
+    from ray._private.accelerators import TPUAcceleratorManager
+    cluster_resources = ray.cluster_resources()
+    total_tpus = int(cluster_resources["TPU"])
+    tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators()
+    assert total_tpus % tpus_per_node == 0
+    return total_tpus // tpus_per_node
+
+def get_num_nodes_in_placement_group() -> int:
+    pg_table = ray.util.placement_group_table()
+    current_pg = ray.util.get_current_placement_group()
+    num_nodes = 0
+    if current_pg:
+        nodes_in_pg = set()
+        for pg_key, pg in pg_table.items():
+            if pg_key == current_pg.id.hex():
+                for _, node in pg["bundles_to_node_id"].items():
+                    nodes_in_pg.add(node)
+        num_nodes = len(nodes_in_pg)
+    return num_nodes

+ 1 - 1
requirements-tpu.txt

@@ -5,4 +5,4 @@
 # Dependencies for TPU
 # Currently, the TPU backend uses a nightly version of PyTorch XLA.
 # You can install the dependencies in Dockerfile.tpu.
-ray
+ray[default]