|
@@ -59,6 +59,9 @@ def get_attn_backend(
|
|
ROCmFlashAttentionBackend # noqa: F401
|
|
ROCmFlashAttentionBackend # noqa: F401
|
|
return ROCmFlashAttentionBackend
|
|
return ROCmFlashAttentionBackend
|
|
elif backend == _Backend.TORCH_SDPA:
|
|
elif backend == _Backend.TORCH_SDPA:
|
|
|
|
+ # TODO: make XPUs work with Torch SDPA.
|
|
|
|
+ assert is_cpu(), RuntimeError(
|
|
|
|
+ "Torch SDPA backend is only used for CPU devices.")
|
|
logger.info("Using Torch SDPA backend.")
|
|
logger.info("Using Torch SDPA backend.")
|
|
from aphrodite.attention.backends.torch_sdpa import TorchSDPABackend
|
|
from aphrodite.attention.backends.torch_sdpa import TorchSDPABackend
|
|
return TorchSDPABackend
|
|
return TorchSDPABackend
|