@@ -86,6 +86,15 @@ class FlashInferBackend(AttentionBackend):
     def get_supported_head_sizes() -> List[int]:
         return [64, 128, 256]
 
+    @staticmethod
+    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
+        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
+            return torch.float8_e4m3fn
+        elif kv_cache_dtype == "fp8_e5m2":
+            return torch.float8_e5m2
+        else:
+            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
+
 
 class FlashInferState(AttentionState):
 
     def __init__(self, runner):
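
For reference, a minimal usage sketch of the new helper (the import path
aphrodite.attention.backends.flashinfer is an assumption; any non-fp8 string
such as "auto" raises ValueError):

    import torch
    # Assumed module path, for illustration only.
    from aphrodite.attention.backends.flashinfer import FlashInferBackend

    assert FlashInferBackend.get_fp8_dtype_for_flashinfer("fp8") == torch.float8_e4m3fn
    assert FlashInferBackend.get_fp8_dtype_for_flashinfer("fp8_e4m3") == torch.float8_e4m3fn
    assert FlashInferBackend.get_fp8_dtype_for_flashinfer("fp8_e5m2") == torch.float8_e5m2
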
@@ -178,8 +187,8 @@ class FlashInferState(AttentionState):
                 self._graph_decode_workspace_buffer, _indptr_buffer,
                 self._graph_indices_buffer, _last_page_len_buffer, "NHD",
                 use_tensor_cores)
-        kv_cache_dtype = get_kv_cache_torch_dtype(
-            self.runner.kv_cache_dtype, self.runner.model_config.dtype)
+        kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
+            self.runner.kv_cache_dtype)
         paged_kv_indptr_tensor_host = torch.arange(0,
                                                    batch_size + 1,
                                                    dtype=torch.int32)
@@ -339,7 +348,7 @@ class FlashInferMetadata(AttentionMetadata):
                 self.page_size,
                 # Disable flashinfer's pos encoding and use Aphrodite's rope.
                 pos_encoding_mode="NONE",
-                data_type=self.data_type)
+            )

     def asdict_zerocopy(self,
                         skip_fields: Optional[Set[str]] = None
@@ -365,7 +374,8 @@ class FlashInferMetadata(AttentionMetadata):
     def decode_metadata(self) -> Optional["FlashInferMetadata"]:
         # Currently chunked prefill is not supported
         if self.num_prefills > 0:
-            assert self.num_decode_tokens == 0
+            assert self.num_decode_tokens == 0, (
+                "Chunked prefill is not supported with flashinfer yet.")
             return None

         return self
@@ -674,6 +684,11 @@ class FlashInferImpl(AttentionImpl):
                 k_scale,
                 v_scale,
             )
+            # The FlashInfer API requires data to be in fp8_e4m3 or fp8_e5m2
+            # to process the cache in fp8.
+            torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
+                self.kv_cache_dtype)
+            kv_cache = kv_cache.view(torch_dtype)

         query = query.contiguous(
         )  # Flashinfer requires query to be contiguous
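
For reference, a standalone sketch of the view step above, assuming the cache
buffer is allocated as uint8 when kv_cache_dtype is fp8 (the shape below is made
up for illustration); Tensor.view(dtype) reinterprets the bytes in place, so no
copy is made before handing the cache to FlashInfer:

    import torch

    # Hypothetical paged KV cache layout: (num_blocks, 2, block_size, num_heads, head_size).
    kv_cache = torch.zeros(4, 2, 16, 8, 128, dtype=torch.uint8)

    # Reinterpret the same bytes as fp8_e4m3; shape and storage are unchanged.
    fp8_view = kv_cache.view(torch.float8_e4m3fn)
    assert fp8_view.shape == kv_cache.shape
    assert fp8_view.data_ptr() == kv_cache.data_ptr()
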
@@ -711,5 +726,7 @@ class FlashInferImpl(AttentionImpl):
                 query,
                 kv_cache,
                 sm_scale=self.scale,
-                logits_soft_cap=self.logits_soft_cap)
+                logits_soft_cap=self.logits_soft_cap,
+                k_scale=k_scale,
+                v_scale=v_scale)
         return output.view(num_tokens, hidden_size)
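
For reference, a rough sketch of what a per-tensor k_scale/v_scale does for an
fp8 cache, assuming the common convention that values are divided by the scale
when written and multiplied by it when read back (the attention kernels do this
on-device; this is only a CPU illustration, not FlashInfer's implementation):

    import torch

    k = torch.randn(16, 128) * 10.0
    # Hypothetical per-tensor scale chosen so the values fit the fp8 range.
    k_scale = k.abs().max() / torch.finfo(torch.float8_e4m3fn).max

    k_fp8 = (k / k_scale).to(torch.float8_e4m3fn)  # what gets stored in the cache
    k_dequant = k_fp8.to(torch.float32) * k_scale  # what attention effectively reads
    print((k - k_dequant).abs().max())             # small quantization error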