Преглед изворног кода

fix: device assertion for sdpa backend; fix env for tpu worker

AlpinDale пре 7 месеци
родитељ
комит
a524667db0
2 измењена фајла са 6 додатих редова и 1 уклоњеним редом
  1. 3 0
      aphrodite/attention/selector.py
  2. 3 1
      aphrodite/task_handler/tpu_worker.py

+ 3 - 0
aphrodite/attention/selector.py

@@ -59,6 +59,9 @@ def get_attn_backend(
             ROCmFlashAttentionBackend  # noqa: F401
             ROCmFlashAttentionBackend  # noqa: F401
         return ROCmFlashAttentionBackend
         return ROCmFlashAttentionBackend
     elif backend == _Backend.TORCH_SDPA:
     elif backend == _Backend.TORCH_SDPA:
+        # TODO: make XPUs work with Torch SDPA.
+        assert is_cpu(), RuntimeError(
+            "Torch SDPA backend is only used for CPU devices.")
         logger.info("Using Torch SDPA backend.")
         logger.info("Using Torch SDPA backend.")
         from aphrodite.attention.backends.torch_sdpa import TorchSDPABackend
         from aphrodite.attention.backends.torch_sdpa import TorchSDPABackend
         return TorchSDPABackend
         return TorchSDPABackend

+ 3 - 1
aphrodite/task_handler/tpu_worker.py

@@ -87,7 +87,9 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         # Use persistent cache to avoid XLA recompilation.
         # Use persistent cache to avoid XLA recompilation.
         # NOTE: This does not completely eliminate the recompilation
         # NOTE: This does not completely eliminate the recompilation
         # overhead because dynamo does not cache the compiled results.
         # overhead because dynamo does not cache the compiled results.
-        xr.initialize_cache(os.path.expanduser(envs.VLLM_XLA_CACHE_PATH),
+        APHRODITE_XLA_CACHE_PATH = os.getenv("APHRODITE_XLA_CACHE_PATH",
+                                             "~/.aphrodite/xla_cache/")
+        xr.initialize_cache(os.path.expanduser(APHRODITE_XLA_CACHE_PATH),
                             readonly=False)
                             readonly=False)
 
 
     def load_model(self):
     def load_model(self):