
chore: set per-rank XLA cache for TPU (#714)

AlpinDale 6 months ago
commit 1c519cc6ac

+ 1 - 0
aphrodite/task_handler/tpu_model_runner.py

@@ -90,6 +90,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         load_config: LoadConfig,
         multimodal_config: Optional[MultiModalConfig] = None,
         is_driver_worker: bool = False,
+        **kwargs,
     ):
         self.model_config = model_config
         self.parallel_config = parallel_config
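
For context, accepting **kwargs in the constructor lets callers forward extra keyword arguments (for example, a newer config object the runner does not yet use) without raising a TypeError. A minimal sketch with a toy class, not Aphrodite's actual TPUModelRunner:

    class Runner:
        def __init__(self, model_config, is_driver_worker=False, **kwargs):
            # Unknown keyword arguments are accepted and silently ignored,
            # so construction keeps working when callers pass newer options.
            self.model_config = model_config
            self.is_driver_worker = is_driver_worker

    # The extra keyword (hypothetical here) is absorbed by **kwargs.
    runner = Runner(model_config=object(), extra_config=None)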

+ 6 - 6
aphrodite/task_handler/tpu_worker.py

@@ -100,14 +100,14 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         # 30-40 graphs for decode. 128 is an arbitrary safe number.
         torch._dynamo.config.cache_size_limit = 128
         # Use persistent cache to avoid XLA recompilation.
-        # NOTE: This does not completely eliminate the recompilation
-        # overhead because dynamo does not cache the compiled results.
+        # NOTE: Set per-rank cache path since different ranks
+        # can have slightly different XLA graphs.
         APHRODITE_XLA_CACHE_PATH = os.getenv("APHRODITE_XLA_CACHE_PATH",
                                              "~/.aphrodite/xla_cache/")
-        # NOTE: Set readonly=False only for the rank 0 process to avoid
-        # race conditions.
-        xr.initialize_cache(os.path.expanduser(APHRODITE_XLA_CACHE_PATH),
-                            readonly=not self.is_driver_worker)
+        world_size = self.parallel_config.world_size
+        per_rank_path = os.path.join(APHRODITE_XLA_CACHE_PATH,
+                                     f"tp{world_size}_rank{self.rank}")
+        xr.initialize_cache(per_rank_path, readonly=False)
 
     def load_model(self):
         self.model_runner.load_model()
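
For reference, a minimal sketch of how the per-rank cache path above resolves. The tp{world_size}_rank{rank} layout and the APHRODITE_XLA_CACHE_PATH default come from the diff; the standalone helper name and the expanduser call are illustrative additions:

    import os

    def per_rank_xla_cache_path(world_size: int, rank: int) -> str:
        # Base directory defaults to the value used in the diff and can be
        # overridden via the APHRODITE_XLA_CACHE_PATH environment variable.
        base = os.getenv("APHRODITE_XLA_CACHE_PATH", "~/.aphrodite/xla_cache/")
        # One directory per (tensor-parallel size, rank) pair, so concurrent
        # ranks never write to the same cache and each can open it with
        # readonly=False without racing the others.
        # expanduser is applied here only so the sketch prints a usable path.
        return os.path.join(os.path.expanduser(base), f"tp{world_size}_rank{rank}")

    # Example: rank 1 of a 4-way tensor-parallel TPU deployment.
    print(per_rank_xla_cache_path(world_size=4, rank=1))
    # e.g. /home/<user>/.aphrodite/xla_cache/tp4_rank1

Compared with the previous readonly=not self.is_driver_worker scheme, separate per-rank directories sidestep write races entirely, at the cost of each rank keeping its own copy of the compiled graphs.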