
fix head_size check for flash attention backend

AlpinDale, 7 months ago
commit b8b63eb5ca

+ 7 - 4
aphrodite/attention/backends/flash_attn.py

@@ -10,11 +10,13 @@ from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                     AttentionImpl,
                                                     AttentionMetadata)
 
-_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256]
-
 
 class FlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_supported_head_sizes() -> List[int]:
+        return [32, 64, 96, 128, 160, 192, 224, 256]
+
     @staticmethod
     def get_name() -> str:
         return "flash-attn"
@@ -238,10 +240,11 @@ class FlashAttentionImpl(AttentionImpl):
             # paged KV cache.
             raise ValueError(
                 "Sliding window is not supported in FlashAttention.")
-        if head_size not in _SUPPORTED_HEAD_SIZES:
+        support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
+        if head_size not in support_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by FlashAttention. "
-                f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.")
+                f"Supported head sizes are: {support_head_sizes}.")
 
     def forward(
         self,
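
For reference, a minimal sketch of how the check reads after this change. The stub class and the check_head_size helper below are hypothetical stand-ins (the real validation lives inside FlashAttentionImpl); the point is that the supported sizes are now queried from the backend class rather than a module-level constant.

from typing import List


class FlashAttentionBackend:
    # Illustrative stub of the real backend class.
    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]


def check_head_size(head_size: int) -> None:
    # Hypothetical helper mirroring the check in FlashAttentionImpl.
    support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
    if head_size not in support_head_sizes:
        raise ValueError(
            f"Head size {head_size} is not supported by FlashAttention. "
            f"Supported head sizes are: {support_head_sizes}.")


check_head_size(128)   # supported, returns silently
# check_head_size(80)  # would raise ValueError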

+ 10 - 3
aphrodite/attention/selector.py

@@ -33,11 +33,18 @@ def get_attn_backend(
                                  sliding_window, dtype, kv_cache_dtype,
                                  block_size)
     if backend == _Backend.FLASH_ATTN:
-        logger.info("Using FlashAttention backend.")
         from aphrodite.attention.backends.flash_attn import \
             FlashAttentionBackend  # noqa: F401
-        return FlashAttentionBackend
-    elif backend == _Backend.XFORMERS:
+        # We check it here not in _which_attn_to_use because we cannot know
+        # the head size until we import FlashAttentionBackend.
+        supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
+        if head_size in supported_head_sizes:
+            logger.info("Using FlashAttention-2 backend.")
+            return FlashAttentionBackend
+        logger.info("Cannot use FlashAttention-2 backend for head "
+                    f"size {head_size}. Using XFormers backend instead.")
+        backend = _Backend.XFORMERS
+    if backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
         from aphrodite.attention.backends.xformers import \
             XFormersBackend  # noqa: F401
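
The reason the original elif became a plain if: the FLASH_ATTN branch can now rewrite backend to XFORMERS and fall through into the next branch instead of failing on an unsupported head size. Below is a self-contained sketch of that fall-through, assuming stub classes and a hypothetical select_backend in place of the real get_attn_backend in aphrodite/attention/selector.py.

import enum
import logging
from typing import List

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()


class FlashAttentionBackend:
    # Illustrative stub; the real class lives in backends/flash_attn.py.
    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]


class XFormersBackend:
    # Illustrative stub; the real class lives in backends/xformers.py.
    pass


def select_backend(backend: _Backend, head_size: int):
    if backend == _Backend.FLASH_ATTN:
        supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
        if head_size in supported_head_sizes:
            logger.info("Using FlashAttention-2 backend.")
            return FlashAttentionBackend
        # Unsupported head size: downgrade and fall through to XFormers.
        logger.info("Cannot use FlashAttention-2 backend for head "
                    f"size {head_size}. Using XFormers backend instead.")
        backend = _Backend.XFORMERS
    if backend == _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
        return XFormersBackend


select_backend(_Backend.FLASH_ATTN, 128)  # -> FlashAttentionBackend
select_backend(_Backend.FLASH_ATTN, 80)   # -> XFormersBackend (fallback)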