|
@@ -29,21 +29,17 @@ def get_attn_backend(
|
|
|
kv_cache_dtype: Optional[str],
|
|
|
block_size: int,
|
|
|
) -> Type[AttentionBackend]:
|
|
|
- backend = _which_attn_to_use(num_heads, head_size, num_kv_heads,
|
|
|
- sliding_window, dtype, kv_cache_dtype,
|
|
|
- block_size)
|
|
|
+ """Determine which attention backend to use and only import
|
|
|
+ the selected backend module.
|
|
|
+ """
|
|
|
+ backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
|
|
|
+ sliding_window, dtype, kv_cache_dtype,
|
|
|
+ block_size)
|
|
|
if backend == _Backend.FLASH_ATTN:
|
|
|
from aphrodite.attention.backends.flash_attn import \
|
|
|
FlashAttentionBackend # noqa: F401
|
|
|
- # We check it here not in _which_attn_to_use because we cannot know
|
|
|
- # the head size until we import FlashAttentionBackend.
|
|
|
- supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
|
|
|
- if head_size in supported_head_sizes:
|
|
|
- logger.info("Using FlashAttention-2 backend.")
|
|
|
- return FlashAttentionBackend
|
|
|
- logger.info("Cannot use FlashAttention-2 backend for head "
|
|
|
- f"size {head_size}. Using XFormers backend instead.")
|
|
|
- backend = _Backend.XFORMERS
|
|
|
+ logger.info("Using FlashAttention backend.")
|
|
|
+ return FlashAttentionBackend
|
|
|
if backend == _Backend.XFORMERS:
|
|
|
logger.info("Using XFormers backend.")
|
|
|
from aphrodite.attention.backends.xformers import \
|
|
@@ -60,14 +56,15 @@ def get_attn_backend(
|
|
|
return TorchSDPABackend
|
|
|
elif backend == _Backend.FLASHINFER:
|
|
|
logger.info("Using Flashinfer backend.")
|
|
|
- logger.warning("Eager mode is enforced for the Flashinfer backend. ")
|
|
|
+ logger.warning("Eager mode is required for the Flashinfer backend. "
|
|
|
+ "Please make sure --enforce-eager is set.")
|
|
|
from aphrodite.attention.backends.flashinfer import FlashInferBackend
|
|
|
return FlashInferBackend
|
|
|
else:
|
|
|
raise ValueError("Invalid attention backend.")
|
|
|
|
|
|
|
|
|
-def _which_attn_to_use(
|
|
|
+def which_attn_to_use(
|
|
|
num_heads: int,
|
|
|
head_size: int,
|
|
|
num_kv_heads: int,
|
|
@@ -77,59 +74,82 @@ def _which_attn_to_use(
|
|
|
block_size: int,
|
|
|
) -> _Backend:
|
|
|
"""Returns which flash attention backend to use."""
|
|
|
+
|
|
|
+ # Default case.
|
|
|
+ selected_backend = _Backend.FLASH_ATTN
|
|
|
+
|
|
|
+ # Check the environment variable and override if specified
|
|
|
+ backend_by_env_var: Optional[str] = os.getenv(APHRODITE_ATTENTION_BACKEND)
|
|
|
+ if backend_by_env_var is not None:
|
|
|
+ backend_members = _Backend.__members__
|
|
|
+ if backend_by_env_var.upper() not in backend_members:
|
|
|
+ raise ValueError(
|
|
|
+ f"Invalid attention backend '{backend_by_env_var}'. "
|
|
|
+ f"Available backends: {', '.join(backend_members)} ")
|
|
|
+ selected_backend = _Backend[backend_by_env_var.upper()]
|
|
|
if is_cpu():
|
|
|
+ if selected_backend != _Backend.TORCH_SDPA:
|
|
|
+ logger.info(f"Cannot use {selected_backend} backend on CPU.")
|
|
|
return _Backend.TORCH_SDPA
|
|
|
|
|
|
if is_hip():
|
|
|
# AMD GPUs.
|
|
|
- if torch.cuda.get_device_capability()[0] != 9:
|
|
|
- # not Instinct series GPUs.
|
|
|
- logger.info("flash_atten is not supported on NAVI GPUs.")
|
|
|
+ selected_backend = (_Backend.ROCM_FLASH if selected_backend
|
|
|
+ == _Backend.FLASH_ATTN else selected_backend)
|
|
|
+ if selected_backend == _Backend.ROCM_FLASH:
|
|
|
+ if torch.cuda.get_device_capability()[0] != 9:
|
|
|
+ # not Instinct series GPUs.
|
|
|
+ logger.info("flash_attn is not supported on NAVI GPUs.")
|
|
|
+ else:
|
|
|
+ logger.info("f{selected_backend} is not supported in AMD GPUs.")
|
|
|
return _Backend.ROCM_FLASH
|
|
|
|
|
|
- # NVIDIA GPUs.
|
|
|
- if torch.cuda.get_device_capability()[0] < 8:
|
|
|
- # Volta and Turing NVIDIA GPUs.
|
|
|
- logger.info("Cannot use FlashAttention backend for Volta and Turing "
|
|
|
- "GPUs.")
|
|
|
- return _Backend.XFORMERS
|
|
|
-
|
|
|
- if dtype not in (torch.float16, torch.bfloat16):
|
|
|
- logger.info("Cannot use FlashAttention backend for dtype other than "
|
|
|
- "torch.float16 or torch.bfloat16.")
|
|
|
- return _Backend.XFORMERS
|
|
|
-
|
|
|
- if block_size % 16 != 0:
|
|
|
- logger.info("Cannot use FlashAttention-2 backend for block size not "
|
|
|
- "divisible by 16.")
|
|
|
- return _Backend.XFORMERS
|
|
|
-
|
|
|
- if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
|
|
|
- logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.")
|
|
|
- return _Backend.XFORMERS
|
|
|
-
|
|
|
- if block_size % 16 != 0:
|
|
|
- logger.info("Cannot use FlashAttention-2 backend for block size not "
|
|
|
- "divisible by 16.")
|
|
|
- return _Backend.XFORMERS
|
|
|
-
|
|
|
- if sliding_window is not None:
|
|
|
- logger.info(
|
|
|
- "Cannot use FlashAttention-2 backend due to sliding window.")
|
|
|
- return _Backend.XFORMERS
|
|
|
-
|
|
|
- try:
|
|
|
- import vllm_flash_attn # noqa: F401
|
|
|
- except ImportError:
|
|
|
- logger.info(
|
|
|
- "Cannot use FlashAttention-2 backend because the vllm_flash_attn "
|
|
|
- "package is not found. `pip install vllm-flash-attn` for better "
|
|
|
- "performance.")
|
|
|
- return _Backend.XFORMERS
|
|
|
-
|
|
|
- backend_by_env_var = os.getenv(APHRODITE_ATTENTION_BACKEND)
|
|
|
- if backend_by_env_var is not None:
|
|
|
- return _Backend[backend_by_env_var.upper()]
|
|
|
-
|
|
|
- # Default case.
|
|
|
- return _Backend.FLASH_ATTN
|
|
|
+ # FlashAttn in NVIDIA GPUs.
|
|
|
+ if selected_backend == _Backend.FLASH_ATTN:
|
|
|
+ if torch.cuda.get_device_capability()[0] < 8:
|
|
|
+ # Volta and Turing NVIDIA GPUs.
|
|
|
+ logger.info(
|
|
|
+ "Cannot use FlashAttention-2 backend for Volta and Turing "
|
|
|
+ "GPUs.")
|
|
|
+ selected_backend = _Backend.XFORMERS
|
|
|
+ elif dtype not in (torch.float16, torch.bfloat16):
|
|
|
+ logger.info(
|
|
|
+ "Cannot use FlashAttention-2 backend for dtype other than "
|
|
|
+ "torch.float16 or torch.bfloat16.")
|
|
|
+ selected_backend = _Backend.XFORMERS
|
|
|
+ elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
|
|
|
+ logger.info(
|
|
|
+ "Cannot use FlashAttention-2 backend for FP8 KV cache.")
|
|
|
+ selected_backend = _Backend.XFORMERS
|
|
|
+ elif block_size % 16 != 0:
|
|
|
+ logger.info(
|
|
|
+ "Cannot use FlashAttention-2 backend for block size not "
|
|
|
+ "divisible by 16.")
|
|
|
+ selected_backend = _Backend.XFORMERS
|
|
|
+ elif sliding_window is not None:
|
|
|
+ logger.info(
|
|
|
+ "Cannot use FlashAttention-2 backend due to sliding window.")
|
|
|
+ selected_backend = _Backend.XFORMERS
|
|
|
+
|
|
|
+ # FlashAttn is valid for the model, checking if the package is installed.
|
|
|
+ if selected_backend == _Backend.FLASH_ATTN:
|
|
|
+ try:
|
|
|
+ import vllm_flash_attn # noqa: F401
|
|
|
+
|
|
|
+ from aphrodite.attention.backends.flash_attn import ( # noqa: F401
|
|
|
+ FlashAttentionBackend)
|
|
|
+
|
|
|
+ supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
|
|
|
+ if head_size not in supported_sizes:
|
|
|
+ logger.info(
|
|
|
+ "Cannot use FlashAttention-2 backend for head size "
|
|
|
+ f"{head_size}")
|
|
|
+ selected_backend = _Backend.XFORMERS
|
|
|
+ except ImportError:
|
|
|
+ logger.info(
|
|
|
+ "Cannot use FlashAttention-2 backend because the "
|
|
|
+ "vllm_flash_attn package is not found. "
|
|
|
+ "`pip install vllm-flash-attn` for better performance.")
|
|
|
+ selected_backend = _Backend.XFORMERS
|
|
|
+
|
|
|
+ return selected_backend
|