1
0
View source

fix: device assertion for sdpa backend; fix env for tpu worker

AlpinDale 7 months ago
parent
commit
a524667db0

+ 3 - 0
aphrodite/attention/selector.py

@@ -59,6 +59,9 @@ def get_attn_backend(
             ROCmFlashAttentionBackend  # noqa: F401
         return ROCmFlashAttentionBackend
     elif backend == _Backend.TORCH_SDPA:
+        # TODO: make XPUs work with Torch SDPA.
+        assert is_cpu(), RuntimeError(
+            "Torch SDPA backend is only used for CPU devices.")
         logger.info("Using Torch SDPA backend.")
         from aphrodite.attention.backends.torch_sdpa import TorchSDPABackend
         return TorchSDPABackend

+ 3 - 1
aphrodite/task_handler/tpu_worker.py

@@ -87,7 +87,9 @@ class TPUWorker(LoraNotSupportedWorkerBase):
         # Use persistent cache to avoid XLA recompilation.
         # NOTE: This does not completely eliminate the recompilation
         # overhead because dynamo does not cache the compiled results.
-        xr.initialize_cache(os.path.expanduser(envs.VLLM_XLA_CACHE_PATH),
+        APHRODITE_XLA_CACHE_PATH = os.getenv("APHRODITE_XLA_CACHE_PATH",
+                                             "~/.aphrodite/xla_cache/")
+        xr.initialize_cache(os.path.expanduser(APHRODITE_XLA_CACHE_PATH),
                             readonly=False)
 
     def load_model(self):