@@ -100,14 +100,14 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         # 30-40 graphs for decode. 128 is an arbitrary safe number.
         torch._dynamo.config.cache_size_limit = 128
         # Use persistent cache to avoid XLA recompilation.
-        # NOTE: This does not completely eliminate the recompilation
-        # overhead because dynamo does not cache the compiled results.
+        # NOTE: Set per-rank cache path since different ranks
+        # can have slightly different XLA graphs.
         APHRODITE_XLA_CACHE_PATH = os.getenv("APHRODITE_XLA_CACHE_PATH",
                                              "~/.aphrodite/xla_cache/")
-        # NOTE: Set readonly=False only for the rank 0 process to avoid
-        # race conditions.
-        xr.initialize_cache(os.path.expanduser(APHRODITE_XLA_CACHE_PATH),
-                            readonly=not self.is_driver_worker)
+        world_size = self.parallel_config.world_size
+        per_rank_path = os.path.join(APHRODITE_XLA_CACHE_PATH,
+                                     f"tp{world_size}_rank{self.rank}")
+        xr.initialize_cache(per_rank_path, readonly=False)

     def load_model(self):
         self.model_runner.load_model()
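
For reference, a minimal standalone sketch of how the per-rank cache path resolves with the default "~/.aphrodite/xla_cache/" base. The world_size and rank values are illustrative stand-ins for self.parallel_config.world_size and self.rank, and os.path.expanduser is assumed to still be needed, since the old code expanded the tilde while the new path is joined without expansion.

import os

# Illustrative values only; the worker reads these from its own
# parallel_config and rank attributes.
world_size = 8
rank = 3

APHRODITE_XLA_CACHE_PATH = os.getenv("APHRODITE_XLA_CACHE_PATH",
                                     "~/.aphrodite/xla_cache/")
# Assumption: the "~" in the default still needs expanding before the
# directory is handed to xr.initialize_cache().
per_rank_path = os.path.expanduser(
    os.path.join(APHRODITE_XLA_CACHE_PATH, f"tp{world_size}_rank{rank}"))
print(per_rank_path)  # e.g. /home/user/.aphrodite/xla_cache/tp8_rank3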