
core: use flashinfer for FP8 KV when available (#944)

AlpinDale · 2 months ago
commit 4ddc14d653
2 changed files with 26 additions and 5 deletions
  1. aphrodite/attention/backends/flashinfer.py  +22 -5
  2. aphrodite/attention/selector.py  +4 -0

+ 22 - 5
aphrodite/attention/backends/flashinfer.py

@@ -86,6 +86,15 @@ class FlashInferBackend(AttentionBackend):
     def get_supported_head_sizes() -> List[int]:
         return [64, 128, 256]
 
+    @staticmethod
+    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
+        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
+            return torch.float8_e4m3fn
+        elif kv_cache_dtype == "fp8_e5m2":
+            return torch.float8_e5m2
+        else:
+            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
+
 
 class FlashInferState(AttentionState):
     def __init__(self, runner):
@@ -178,8 +187,8 @@ class FlashInferState(AttentionState):
             self._graph_decode_workspace_buffer, _indptr_buffer,
             self._graph_indices_buffer, _last_page_len_buffer, "NHD",
             use_tensor_cores)
-        kv_cache_dtype = get_kv_cache_torch_dtype(
-            self.runner.kv_cache_dtype, self.runner.model_config.dtype)
+        kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
+            self.runner.kv_cache_dtype)
         paged_kv_indptr_tensor_host = torch.arange(0,
                                                    batch_size + 1,
                                                    dtype=torch.int32)
@@ -339,7 +348,7 @@ class FlashInferMetadata(AttentionMetadata):
                 self.page_size,
                 # Disable flashinfer's pos encoding and use Aphrodite's rope.
                 pos_encoding_mode="NONE",
-                data_type=self.data_type)
+            )
 
     def asdict_zerocopy(self,
                         skip_fields: Optional[Set[str]] = None
@@ -365,7 +374,8 @@ class FlashInferMetadata(AttentionMetadata):
     def decode_metadata(self) -> Optional["FlashInferMetadata"]:
         # Currently chunked prefill is not supported
         if self.num_prefills > 0:
-            assert self.num_decode_tokens == 0
+            assert self.num_decode_tokens == 0, (
+                "Chunked prefill is not supported with flashinfer yet.")
             return None
 
         return self
@@ -674,6 +684,11 @@ class FlashInferImpl(AttentionImpl):
                 k_scale,
                 v_scale,
             )
+            # The FlashInfer API requires data to be in fp8_e4m3 or fp8_e5m2
+            # to process the cache in FP8.
+            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
+                self.kv_cache_dtype)
+            kv_cache = kv_cache.view(torch_dtype)
 
         query = query.contiguous(
         )  # Flashinfer requires query to be contiguous
@@ -711,5 +726,7 @@ class FlashInferImpl(AttentionImpl):
                 query,
                 kv_cache,
                 sm_scale=self.scale,
-                logits_soft_cap=self.logits_soft_cap)
+                logits_soft_cap=self.logits_soft_cap,
+                k_scale=k_scale,
+                v_scale=v_scale)
         return output.view(num_tokens, hidden_size)
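
For context, a minimal standalone sketch (not part of the commit; assumes PyTorch >= 2.1 for the FP8 dtypes) of the reinterpretation the added lines rely on: the paged KV cache is stored as one-byte elements, so Tensor.view(dtype) can retag it as fp8_e4m3 or fp8_e5m2 for FlashInfer without copying any data.

import torch

# Hypothetical stand-in for a paged KV cache allocated as raw bytes.
kv_cache = torch.zeros(2, 16, 64, dtype=torch.uint8)

# Same mapping as FlashInferBackend.get_fp8_dtype_for_flashinfer above;
# "fp8" and "fp8_e4m3" map to float8_e4m3fn, "fp8_e5m2" to float8_e5m2.
torch_dtype = torch.float8_e4m3fn

# view(dtype) reinterprets the existing storage; both FP8 formats are
# one byte wide, so the shape is unchanged and nothing is copied.
kv_cache = kv_cache.view(torch_dtype)
assert kv_cache.dtype == torch.float8_e4m3fn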

+ 4 - 0
aphrodite/attention/selector.py

@@ -235,6 +235,10 @@ def which_attn_to_use(
         elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
             logger.info(
                 "Cannot use FlashAttention-2 backend for FP8 KV cache.")
+            logger.warning(
+                "Please use FlashInfer backend with FP8 KV Cache for "
+                "better performance by setting the environment "
+                "variable APHRODITE_ATTENTION_BACKEND=FLASHINFER")
             selected_backend = _Backend.XFORMERS
         elif block_size % 16 != 0:
             logger.info(
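
A hedged usage sketch of what the new warning points at: selecting the FlashInfer backend together with an FP8 KV cache. The environment variable name is taken from the message above; the aphrodite.LLM import and the kv_cache_dtype argument are assumptions about the engine's Python API and may differ between versions.

import os

# Assumption: the backend env var must be set before the engine is built,
# since the attention backend is chosen during engine construction.
os.environ["APHRODITE_ATTENTION_BACKEND"] = "FLASHINFER"

from aphrodite import LLM, SamplingParams  # assumed import path

# kv_cache_dtype accepts "fp8"/"fp8_e4m3" or "fp8_e5m2", per the helper above.
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct",  # example model
          kv_cache_dtype="fp8")
print(llm.generate(["Hello, "], SamplingParams(max_tokens=16)))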