from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Type

import flashinfer
import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from vllm_flash_attn import flash_attn_varlen_func

from aphrodite._C import cache_ops as ops
from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                   AttentionImpl,
                                                   AttentionMetadata)


class FlashInferBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "flashinfer"

    @staticmethod
    def get_impl_cls() -> Type["FlashInferImpl"]:
        return FlashInferImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "FlashInferMetadata":
        return FlashInferMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
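        # Dimension 1 stores K and V together: the forward pass below writes
        # keys to kv_cache[:, 0] and values to kv_cache[:, 1].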
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        raise NotImplementedError

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        raise NotImplementedError

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [64, 128, 256]
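

# Illustrative sketch, not part of the original backend: allocating a cache
# tensor with the shape above. The sizes and dtype below are hypothetical.
#
# kv_cache = torch.empty(
#     FlashInferBackend.get_kv_cache_shape(
#         num_blocks=1024, block_size=16, num_kv_heads=8, head_size=128),
#     dtype=torch.float16, device="cuda")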


@dataclass
class FlashInferMetadata(AttentionMetadata):
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int

    use_cuda_graph: bool = False

    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None

    # Metadata for the prefill stage since we still
    # use flash attention for prefill.
    seq_start_loc: Optional[torch.Tensor] = None
    block_tables: Optional[torch.Tensor] = None

    # Metadata for the decode stage
    # Workspace buffer required by the kernel; the buffer should not be
    # allocated or deallocated by the FlashInferMetadata object.
    workspace_buffer: Optional[torch.Tensor] = None
    # An example for paged_kv_indices, paged_kv_indptr:
    # request 1, page indices [0, 5, 8]
    # request 2, page indices [1, 6, 7]
    # request 3, page indices [3, 4]
    # paged_kv_indices is a concatenation of page indices of all requests:
    # [0, 5, 8, 1, 6, 7, 3, 4]
    # paged_kv_indptr is used to index into paged_kv_indices:
    # [0, 3, 6, 8]
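    # So the pages of request i are
    # paged_kv_indices[paged_kv_indptr[i]:paged_kv_indptr[i + 1]], e.g. the
    # second request's pages are paged_kv_indices[3:6] == [1, 6, 7].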
    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
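    # For example, with a page_size of 16, a request with 40 cached tokens
    # spans 3 pages and its last_page_len is 8.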
    paged_kv_last_page_len: Optional[torch.Tensor] = None
    # The number of query/output heads
    num_qo_heads: Optional[int] = None
    # The number of key/value heads
    num_kv_heads: Optional[int] = None
    # The dimension of the attention heads
    head_dim: Optional[int] = None
    # Block size of vllm
    page_size: Optional[int] = None
    # The data type of the paged kv cache
    data_type: torch.dtype = None

    def __post_init__(self):
        # Refer to
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
        if self.head_dim is not None and self.head_dim \
                not in supported_head_sizes:
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")

        # When using flashinfer, creating the FlashInferMetadata will also
        # call __post_init__ by default. Skip the decode-wrapper setup below
        # if this is the prefill phase.
        if self.num_prefills == 0:
            assert self.num_decode_tokens > 0
            self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
                self.workspace_buffer, "NHD")
            self.decode_wrapper.begin_forward(
                self.paged_kv_indptr,
                self.paged_kv_indices,
                self.paged_kv_last_page_len,
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                # Disable flashinfer's pos encoding and use vllm's rope.
                pos_encoding_mode="NONE",
                data_type=self.data_type)
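            # The plan created by begin_forward() is consumed later by
            # decode_wrapper.forward() in FlashInferImpl.forward().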

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        if skip_fields is None:
            skip_fields = set()
        # We need to skip the decode_wrapper field since it cannot be
        # broadcasted with nccl when TP is enabled.
        skip_fields.add('decode_wrapper')
        return super().asdict_zerocopy(skip_fields)

    @property
    def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
        # Currently chunked prefill is not supported
        if self.num_decode_tokens == 0:
            assert self.num_prefills > 0
            return self
        return None

    @property
    def decode_metadata(self) -> Optional["FlashInferMetadata"]:
        # Currently chunked prefill is not supported
        if self.num_prefills > 0:
            assert self.num_decode_tokens == 0
            return None
        return self
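

# Illustrative sketch, not part of the original backend: one way the paged-KV
# tensors documented above could be assembled from per-request page lists.
# The helper name and its plain-Python inputs are hypothetical; in practice
# the engine's model runner builds these tensors.
def _example_build_paged_kv_inputs(
        page_lists: List[List[int]],
        seq_lens: List[int],
        page_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # page_lists, e.g. [[0, 5, 8], [1, 6, 7], [3, 4]] as in the example above.
    indptr = [0]
    indices: List[int] = []
    last_page_len: List[int] = []
    for pages, seq_len in zip(page_lists, seq_lens):
        indices.extend(pages)
        indptr.append(len(indices))
        # Number of valid entries in the final (possibly partial) page.
        rem = seq_len % page_size
        last_page_len.append(rem if rem != 0 else page_size)
    return (torch.tensor(indptr, dtype=torch.int32),
            torch.tensor(indices, dtype=torch.int32),
            torch.tensor(last_page_len, dtype=torch.int32))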


class FlashInferImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        assert blocksparse_params is None, (
            "FlashInfer does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is not None:
            raise ValueError("Sliding window is not supported in FlashInfer.")
        self.sliding_window = (-1, -1)
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
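        # With grouped-query attention, e.g. 32 query heads over 8 KV heads,
        # each KV head serves num_queries_per_kv == 4 query heads.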

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: FlashInferMetadata,
        kv_scale: float = 1.0,
    ) -> torch.Tensor:
        assert kv_scale == 1.0
        num_tokens, hidden_size = query.shape
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if attn_metadata.num_prefill_tokens > 0:
            assert attn_metadata.num_decode_tokens == 0, (
                "Chunked prefill is not supported with flashinfer yet.")
        if attn_metadata.num_decode_tokens > 0:
            assert attn_metadata.num_prefill_tokens == 0, (
                "Chunked prefill is not supported with flashinfer yet.")

        if kv_cache is not None:
            # Use the same reshape and cache kernel as flash attention.
            ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping.flatten(),
                self.kv_cache_dtype,
            )

        if prefill_meta := attn_metadata.prefill_metadata:
            assert prefill_meta.block_tables is not None
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                output = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
            else:
                raise NotImplementedError(
                    "Prefix caching is not supported with flashinfer yet.")
        else:
            assert attn_metadata.decode_metadata is not None
            assert attn_metadata.decode_metadata.decode_wrapper is not None
            # FlashInfer requires the query to be contiguous.
            query = query.contiguous()
            output = attn_metadata.decode_metadata.decode_wrapper.forward(
                query,
                kv_cache,
                sm_scale=self.scale,
            )
        return output.view(num_tokens, hidden_size)
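

# Illustrative sketch, not part of the original backend: constructing the impl
# class the way an attention layer would, via the backend. All sizes and the
# kv_cache_dtype value below are hypothetical.
#
# impl = FlashInferBackend.get_impl_cls()(
#     num_heads=32,
#     head_size=128,
#     scale=128 ** -0.5,
#     num_kv_heads=8,
#     alibi_slopes=None,
#     sliding_window=None,
#     kv_cache_dtype="auto",
# )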