@@ -38,8 +38,8 @@ def get_attn_backend(

     if is_blocksparse:
         logger.info("Using BlocksparseFlashAttention backend.")
-        from aphrodite.attention.backends.blocksparse_attn import \
-            BlocksparseFlashAttentionBackend
+        from aphrodite.attention.backends.blocksparse_attn import (
+            BlocksparseFlashAttentionBackend)
         return BlocksparseFlashAttentionBackend
     """Determine which attention backend to use and only import
     the selected backend module.
@@ -48,18 +48,18 @@ def get_attn_backend(
                                 sliding_window, dtype, kv_cache_dtype,
                                 block_size)
     if backend == _Backend.FLASH_ATTN:
-        from aphrodite.attention.backends.flash_attn import \
-            FlashAttentionBackend  # noqa: F401
+        from aphrodite.attention.backends.flash_attn import (  # noqa: F401
+            FlashAttentionBackend)
         return FlashAttentionBackend
     if backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
-        from aphrodite.attention.backends.xformers import \
-            XFormersBackend  # noqa: F401
+        from aphrodite.attention.backends.xformers import (  # noqa: F401
+            XFormersBackend)
         return XFormersBackend
     elif backend == _Backend.ROCM_FLASH:
         logger.info("Using ROCmFlashAttention backend.")
-        from aphrodite.attention.backends.rocm_flash_attn import \
-            ROCmFlashAttentionBackend  # noqa: F401
+        from aphrodite.attention.backends.rocm_flash_attn import (  # noqa: F401
+            ROCmFlashAttentionBackend)
         return ROCmFlashAttentionBackend
     elif backend == _Backend.TORCH_SDPA:
         assert is_cpu(), RuntimeError(
@@ -69,8 +69,8 @@ def get_attn_backend(
         return TorchSDPABackend
     elif backend == _Backend.OPENVINO:
         logger.info("Using OpenVINO attention backend.")
-        from aphrodite.attention.backends.openvino import \
-            OpenVINOAttentionBackend
+        from aphrodite.attention.backends.openvino import (
+            OpenVINOAttentionBackend)
         return OpenVINOAttentionBackend
     elif backend == _Backend.IPEX:
         assert is_xpu(), RuntimeError(
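
Context for the hunks above: `get_attn_backend` imports each backend module inside the branch that selects it, so only the chosen backend (and its platform-specific dependencies) is ever loaded. Below is a minimal runnable sketch of that lazy-dispatch pattern — stdlib modules stand in for the aphrodite backends, and `_Serializer`/`get_serializer` are illustrative names, not from this repo:

    import enum


    class _Serializer(enum.Enum):
        JSON = enum.auto()
        PICKLE = enum.auto()


    def get_serializer(backend: _Serializer):
        """Import the chosen module only when its branch is taken,
        mirroring how get_attn_backend defers backend imports."""
        if backend == _Serializer.JSON:
            # json is loaded only if this branch runs.
            from json import (
                dumps)
            return dumps
        elif backend == _Serializer.PICKLE:
            # pickle is loaded only if this branch runs.
            from pickle import (
                dumps)
            return dumps
        raise ValueError(f"Unsupported serializer: {backend}")


    print(get_serializer(_Serializer.JSON)({"head_size": 64}))
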
@@ -177,8 +177,8 @@ def which_attn_to_use(
         try:
             import aphrodite_flash_attn  # noqa: F401

-            from aphrodite.attention.backends.flash_attn import \
-                FlashAttentionBackend  # noqa: F401
+            from aphrodite.attention.backends.flash_attn import (  # noqa: F401
+                FlashAttentionBackend)

             supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
             if head_size not in supported_sizes:
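
All four hunks make the same mechanical change: backslash line continuations are replaced by PEP 8's preferred parenthesized form, which permits an inline `# noqa: F401` on the opening line and reflows cleanly under auto-formatters. (F401 is flake8's "imported but unused" code; the bare `import aphrodite_flash_attn` in the last hunk is purely an availability probe, so the suppression is genuinely needed there.) A small before/after sketch, shown with a stdlib module since the aphrodite backends aren't importable outside the repo:

    # Before: backslash continuation (the style the diff removes).
    # Brittle: any character after the backslash, even a trailing
    # space or an inline comment, is a SyntaxError.
    from importlib import \
        import_module

    # After: implicit continuation inside parentheses (the style the
    # diff adopts). An inline comment can now ride on the opening
    # line, exactly where the hunks above put "# noqa: F401".
    from importlib import (  # an "# noqa: F401" marker could go here
        import_module)

    print(import_module("math").sqrt(2.0))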