
fix: device assertion for sdpa backend; fix env for tpu worker

AlpinDale committed 7 months ago
parent commit a524667db0
2 changed files with 6 additions and 1 deletion
  1. aphrodite/attention/selector.py (+3, -0)
  2. aphrodite/task_handler/tpu_worker.py (+3, -1)

aphrodite/attention/selector.py (+3, -0)

@@ -59,6 +59,9 @@ def get_attn_backend(
             ROCmFlashAttentionBackend  # noqa: F401
         return ROCmFlashAttentionBackend
     elif backend == _Backend.TORCH_SDPA:
+        # TODO: make XPUs work with Torch SDPA.
+        assert is_cpu(), RuntimeError(
+            "Torch SDPA backend is only used for CPU devices.")
         logger.info("Using Torch SDPA backend.")
         from aphrodite.attention.backends.torch_sdpa import TorchSDPABackend
         return TorchSDPABackend
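
A note on the new assertion: in Python, the second operand of `assert` becomes the payload of the raised AssertionError, so the RuntimeError instance above is stringified into the assertion message rather than raised itself, and the whole check is stripped when the interpreter runs with `-O`. A minimal standalone sketch of that behavior (illustrative only, not Aphrodite code; the `is_cpu` parameter here stands in for the real `is_cpu()` helper):

    def get_backend_checked(is_cpu: bool) -> str:
        # The second operand of `assert` is attached to the AssertionError;
        # the RuntimeError instance only supplies the message text.
        assert is_cpu, RuntimeError(
            "Torch SDPA backend is only used for CPU devices.")
        return "TorchSDPABackend"

    try:
        get_backend_checked(is_cpu=False)
    except AssertionError as exc:
        print(exc)  # -> Torch SDPA backend is only used for CPU devices.

An explicit `if not is_cpu(): raise RuntimeError(...)` would survive `-O` and raise the stated exception type; the assert form trades that for brevity on a condition treated as an internal invariant.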

aphrodite/task_handler/tpu_worker.py (+3, -1)

@@ -87,7 +87,9 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         # Use persistent cache to avoid XLA recompilation.
         # NOTE: This does not completely eliminate the recompilation
         # overhead because dynamo does not cache the compiled results.
-        xr.initialize_cache(os.path.expanduser(envs.VLLM_XLA_CACHE_PATH),
+        APHRODITE_XLA_CACHE_PATH = os.getenv("APHRODITE_XLA_CACHE_PATH",
+                                             "~/.aphrodite/xla_cache/")
+        xr.initialize_cache(os.path.expanduser(APHRODITE_XLA_CACHE_PATH),
                             readonly=False)
 
     def load_model(self):
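
The second change swaps the inherited vLLM environment variable for an Aphrodite-specific one with a default cache directory under the user's home; previously, setting APHRODITE_XLA_CACHE_PATH had no effect because the worker read envs.VLLM_XLA_CACHE_PATH, which is presumably the "fix env for tpu worker" part of the commit message. A minimal standalone sketch of the path-resolution logic (the xr.initialize_cache call from the patch is omitted here since it requires torch_xla):

    import os

    # Read the Aphrodite-specific override; fall back to the default
    # cache directory under the user's home, then expand the "~".
    APHRODITE_XLA_CACHE_PATH = os.getenv("APHRODITE_XLA_CACHE_PATH",
                                         "~/.aphrodite/xla_cache/")
    cache_dir = os.path.expanduser(APHRODITE_XLA_CACHE_PATH)
    print(cache_dir)  # e.g. /home/user/.aphrodite/xla_cache/

Persisting the XLA compilation cache across runs avoids recompiling on every worker start, though, as the in-code note says, dynamo does not cache its compiled results, so the recompilation overhead is reduced rather than eliminated.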