@@ -33,11 +33,18 @@ def get_attn_backend(
                                  sliding_window, dtype, kv_cache_dtype,
                                  block_size)
     if backend == _Backend.FLASH_ATTN:
-        logger.info("Using FlashAttention backend.")
         from aphrodite.attention.backends.flash_attn import \
             FlashAttentionBackend  # noqa: F401
-        return FlashAttentionBackend
-    elif backend == _Backend.XFORMERS:
+        # We check it here not in _which_attn_to_use because we cannot know
+        # the head size until we import FlashAttentionBackend.
+        supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
+        if head_size in supported_head_sizes:
+            logger.info("Using FlashAttention-2 backend.")
+            return FlashAttentionBackend
+        logger.info("Cannot use FlashAttention-2 backend for head "
+                    f"size {head_size}. Using XFormers backend instead.")
+        backend = _Backend.XFORMERS
+    if backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
         from aphrodite.attention.backends.xformers import \
             XFormersBackend # noqa: F401
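
To make the control flow easier to see outside the diff, here is a minimal, self-contained sketch of the same fallback pattern: the head-size check is deferred until the backend class is in scope, and selection falls through to XFormers when the size is unsupported. `FlashBackendSketch`, `XFormersBackendSketch`, `select_backend`, and the head-size list are illustrative stand-ins, not the Aphrodite API.

```python
# Sketch of the fallback pattern this diff introduces (assumed names).
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class FlashBackendSketch:
    """Stand-in for FlashAttentionBackend; head sizes are illustrative."""

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]


class XFormersBackendSketch:
    """Stand-in for the XFormers fallback backend."""


def select_backend(head_size: int):
    # Mirrors the diff's comment: the supported sizes are only known
    # once the backend class has been imported, so check them here.
    if head_size in FlashBackendSketch.get_supported_head_sizes():
        logger.info("Using FlashAttention-2 backend.")
        return FlashBackendSketch
    logger.info("Cannot use FlashAttention-2 backend for head "
                f"size {head_size}. Using XFormers backend instead.")
    return XFormersBackendSketch


if __name__ == "__main__":
    assert select_backend(128) is FlashBackendSketch
    assert select_backend(80) is XFormersBackendSketch  # 80 not in the list
```

One consequence of the diff worth noting: replacing `elif` with a second `if` is what lets the FlashAttention branch reassign `backend = _Backend.XFORMERS` and fall through into the XFormers branch in the same pass.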