
prioritize user selection for attention

AlpinDale, 7 months ago
Commit 19a959a03e
1 changed file with 83 additions and 63 deletions

+ 83 - 63
aphrodite/attention/selector.py

@@ -29,21 +29,17 @@ def get_attn_backend(
     kv_cache_dtype: Optional[str],
     block_size: int,
 ) -> Type[AttentionBackend]:
-    backend = _which_attn_to_use(num_heads, head_size, num_kv_heads,
-                                 sliding_window, dtype, kv_cache_dtype,
-                                 block_size)
+    """Determine which attention backend to use and only import
+    the selected backend module.
+    """
+    backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
+                                sliding_window, dtype, kv_cache_dtype,
+                                block_size)
     if backend == _Backend.FLASH_ATTN:
         from aphrodite.attention.backends.flash_attn import \
             FlashAttentionBackend  # noqa: F401
-        # We check it here not in _which_attn_to_use because we cannot know
-        # the head size until we import FlashAttentionBackend.
-        supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
-        if head_size in supported_head_sizes:
-            logger.info("Using FlashAttention-2 backend.")
-            return FlashAttentionBackend
-        logger.info("Cannot use FlashAttention-2 backend for head "
-                    f"size {head_size}. Using XFormers backend instead.")
-        backend = _Backend.XFORMERS
+        logger.info("Using FlashAttention backend.")
+        return FlashAttentionBackend
     if backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
         from aphrodite.attention.backends.xformers import \
@@ -60,14 +56,15 @@ def get_attn_backend(
         return TorchSDPABackend
     elif backend == _Backend.FLASHINFER:
         logger.info("Using Flashinfer backend.")
-        logger.warning("Eager mode is enforced for the Flashinfer backend. ")
+        logger.warning("Eager mode is required for the Flashinfer backend. "
+                       "Please make sure --enforce-eager is set.")
         from aphrodite.attention.backends.flashinfer import FlashInferBackend
         return FlashInferBackend
     else:
         raise ValueError("Invalid attention backend.")
 
 
-def _which_attn_to_use(
+def which_attn_to_use(
     num_heads: int,
     head_size: int,
     num_kv_heads: int,
@@ -77,59 +74,82 @@ def _which_attn_to_use(
     block_size: int,
 ) -> _Backend:
     """Returns which flash attention backend to use."""
+
+    # Default case.
+    selected_backend = _Backend.FLASH_ATTN
+
+    # Check the environment variable and override if specified
+    backend_by_env_var: Optional[str] = os.getenv(APHRODITE_ATTENTION_BACKEND)
+    if backend_by_env_var is not None:
+        backend_members = _Backend.__members__
+        if backend_by_env_var.upper() not in backend_members:
+            raise ValueError(
+                f"Invalid attention backend '{backend_by_env_var}'. "
+                f"Available backends: {', '.join(backend_members)} ")
+        selected_backend = _Backend[backend_by_env_var.upper()]
     if is_cpu():
+        if selected_backend != _Backend.TORCH_SDPA:
+            logger.info(f"Cannot use {selected_backend} backend on CPU.")
         return _Backend.TORCH_SDPA
 
     if is_hip():
         # AMD GPUs.
-        if torch.cuda.get_device_capability()[0] != 9:
-            # not Instinct series GPUs.
-            logger.info("flash_atten is not supported on NAVI GPUs.")
+        selected_backend = (_Backend.ROCM_FLASH if selected_backend
+                            == _Backend.FLASH_ATTN else selected_backend)
+        if selected_backend == _Backend.ROCM_FLASH:
+            if torch.cuda.get_device_capability()[0] != 9:
+                # not Instinct series GPUs.
+                logger.info("flash_attn is not supported on NAVI GPUs.")
+        else:
+            logger.info("f{selected_backend} is not supported in AMD GPUs.")
         return _Backend.ROCM_FLASH
 
-    # NVIDIA GPUs.
-    if torch.cuda.get_device_capability()[0] < 8:
-        # Volta and Turing NVIDIA GPUs.
-        logger.info("Cannot use FlashAttention backend for Volta and Turing "
-                    "GPUs.")
-        return _Backend.XFORMERS
-
-    if dtype not in (torch.float16, torch.bfloat16):
-        logger.info("Cannot use FlashAttention backend for dtype other than "
-                    "torch.float16 or torch.bfloat16.")
-        return _Backend.XFORMERS
-
-    if block_size % 16 != 0:
-        logger.info("Cannot use FlashAttention-2 backend for block size not "
-                    "divisible by 16.")
-        return _Backend.XFORMERS
-
-    if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
-        logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.")
-        return _Backend.XFORMERS
-
-    if block_size % 16 != 0:
-        logger.info("Cannot use FlashAttention-2 backend for block size not "
-                    "divisible by 16.")
-        return _Backend.XFORMERS
-
-    if sliding_window is not None:
-        logger.info(
-            "Cannot use FlashAttention-2 backend due to sliding window.")
-        return _Backend.XFORMERS
-
-    try:
-        import vllm_flash_attn  # noqa: F401
-    except ImportError:
-        logger.info(
-            "Cannot use FlashAttention-2 backend because the vllm_flash_attn "
-            "package is not found. `pip install vllm-flash-attn` for better "
-            "performance.")
-        return _Backend.XFORMERS
-
-    backend_by_env_var = os.getenv(APHRODITE_ATTENTION_BACKEND)
-    if backend_by_env_var is not None:
-        return _Backend[backend_by_env_var.upper()]
-
-    # Default case.
-    return _Backend.FLASH_ATTN
+    # FlashAttn on NVIDIA GPUs.
+    if selected_backend == _Backend.FLASH_ATTN:
+        if torch.cuda.get_device_capability()[0] < 8:
+            # Volta and Turing NVIDIA GPUs.
+            logger.info(
+                "Cannot use FlashAttention-2 backend for Volta and Turing "
+                "GPUs.")
+            selected_backend = _Backend.XFORMERS
+        elif dtype not in (torch.float16, torch.bfloat16):
+            logger.info(
+                "Cannot use FlashAttention-2 backend for dtype other than "
+                "torch.float16 or torch.bfloat16.")
+            selected_backend = _Backend.XFORMERS
+        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
+            logger.info(
+                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
+            selected_backend = _Backend.XFORMERS
+        elif block_size % 16 != 0:
+            logger.info(
+                "Cannot use FlashAttention-2 backend for block size not "
+                "divisible by 16.")
+            selected_backend = _Backend.XFORMERS
+        elif sliding_window is not None:
+            logger.info(
+                "Cannot use FlashAttention-2 backend due to sliding window.")
+            selected_backend = _Backend.XFORMERS
+
+    # FlashAttn is valid for the model; check whether the package is installed.
+    if selected_backend == _Backend.FLASH_ATTN:
+        try:
+            import vllm_flash_attn  # noqa: F401
+
+            from aphrodite.attention.backends.flash_attn import (  # noqa: F401
+                FlashAttentionBackend)
+
+            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
+            if head_size not in supported_sizes:
+                logger.info(
+                    "Cannot use FlashAttention-2 backend for head size "
+                    f"{head_size}")
+                selected_backend = _Backend.XFORMERS
+        except ImportError:
+            logger.info(
+                "Cannot use FlashAttention-2 backend because the "
+                "vllm_flash_attn package is not found. "
+                "`pip install vllm-flash-attn` for better performance.")
+            selected_backend = _Backend.XFORMERS
+
+    return selected_backend
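
The new which_attn_to_use starts from _Backend.FLASH_ATTN and applies the user's choice from the APHRODITE_ATTENTION_BACKEND environment variable before any of the hardware and dtype checks run. The sketch below is a minimal, self-contained illustration of that override step only, not the upstream implementation; it assumes the APHRODITE_ATTENTION_BACKEND constant resolves to an environment variable of the same name (the constant's value is not shown in this diff) and reuses the backend names that appear above.

# Minimal sketch of the user-override step added in this commit, not the
# upstream implementation. Assumes the environment variable is literally
# named "APHRODITE_ATTENTION_BACKEND" and reuses the backend names from
# the diff above.
import os
from enum import Enum, auto
from typing import Optional


class _Backend(Enum):
    FLASH_ATTN = auto()
    XFORMERS = auto()
    ROCM_FLASH = auto()
    TORCH_SDPA = auto()
    FLASHINFER = auto()


def pick_backend() -> _Backend:
    # Default case, as in the diff.
    selected_backend = _Backend.FLASH_ATTN

    # Check the environment variable and override if specified.
    backend_by_env_var: Optional[str] = os.getenv(
        "APHRODITE_ATTENTION_BACKEND")
    if backend_by_env_var is not None:
        backend_members = _Backend.__members__
        if backend_by_env_var.upper() not in backend_members:
            raise ValueError(
                f"Invalid attention backend '{backend_by_env_var}'. "
                f"Available backends: {', '.join(backend_members)}")
        selected_backend = _Backend[backend_by_env_var.upper()]
    return selected_backend


if __name__ == "__main__":
    # e.g. APHRODITE_ATTENTION_BACKEND=xformers python demo.py
    print(pick_backend())

Unlike the pre-commit behavior, where the environment variable was consulted only after every hardware and dtype check had already passed, the override now happens first, so a user-selected backend such as XFORMERS or FLASHINFER is honored (with fallbacks and warnings applied afterwards) even when FlashAttention would have been viable.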