flashinfer.py

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Type

import flashinfer
import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from vllm_flash_attn import flash_attn_varlen_func

from aphrodite._C import cache_ops as ops
from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                   AttentionImpl,
                                                   AttentionMetadata)


class FlashInferBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "flashinfer"

    @staticmethod
    def get_impl_cls() -> Type["FlashInferImpl"]:
        return FlashInferImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "FlashInferMetadata":
        return FlashInferMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (num_blocks, 2, block_size, num_kv_heads, head_size)
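        # Example (illustrative numbers, not values from this file):
        # num_blocks=1024, block_size=16, num_kv_heads=8, head_size=128 gives
        # one cache tensor of shape (1024, 2, 16, 8, 128); index 0/1 on the
        # second dimension selects the key/value plane respectively (see
        # reshape_and_cache_flash in FlashInferImpl.forward below).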
    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        raise NotImplementedError

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        raise NotImplementedError

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [64, 128, 256]


@dataclass
class FlashInferMetadata(AttentionMetadata):
    # Maximum sequence length among the prefill batch. 0 if there are
    # decoding requests only.
    max_prefill_seq_len: int

    use_cuda_graph: bool = False

    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None

    # Metadata for the prefill stage, since we still
    # use flash attention for prefill.
    seq_start_loc: Optional[torch.Tensor] = None
    block_tables: Optional[torch.Tensor] = None

    # Metadata for the decode stage.
    # Workspace buffer required by the kernel; the buffer should not
    # be allocated/deallocated by the FlashInferMetadata object.
    workspace_buffer: Optional[torch.Tensor] = None
    # An example for paged_kv_indices, paged_kv_indptr:
    # request 1, page indices [0, 5, 8]
    # request 2, page indices [1, 6, 7]
    # request 3, page indices [3, 4]
    # paged_kv_indices is a concatenation of page indices of all requests:
    # [0, 5, 8, 1, 6, 7, 3, 4]
    # paged_kv_indptr is used to index into paged_kv_indices:
    # [0, 3, 6, 8]
    #
    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_len: Optional[torch.Tensor] = None
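    # Example (illustrative numbers): with page_size=16, a request whose
    # sequence length is 35 spans 3 pages and has
    # paged_kv_last_page_len = 35 - 2 * 16 = 3.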
    # The number of query/output heads
    num_qo_heads: Optional[int] = None
    # The number of key/value heads
    num_kv_heads: Optional[int] = None
    # The dimension of the attention heads
    head_dim: Optional[int] = None
    # Block size of vllm
    page_size: Optional[int] = None
    # The data type of the paged kv cache
    data_type: Optional[torch.dtype] = None

    def __post_init__(self):
        # Refer to
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
        if self.head_dim is not None and self.head_dim \
                not in supported_head_sizes:
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")

        # FlashInferMetadata is also created for the prefill phase, which
        # triggers __post_init__ as well; only build the decode wrapper when
        # the batch is decode-only.
        if self.num_prefills == 0:
            assert self.num_decode_tokens > 0
            self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
                self.workspace_buffer, "NHD")
            self.decode_wrapper.begin_forward(
                self.paged_kv_indptr,
                self.paged_kv_indices,
                self.paged_kv_last_page_len,
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                # Disable flashinfer's pos encoding and use vllm's rope.
                pos_encoding_mode="NONE",
                data_type=self.data_type)
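
    # Note on the wrapper construction above: the workspace buffer is only
    # wrapped here, never allocated. Callers are expected to allocate it once
    # up front, for example (the size below is an illustrative assumption,
    # not a value taken from this file):
    #   workspace_buffer = torch.empty(
    #       128 * 1024 * 1024, dtype=torch.uint8, device="cuda")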

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        if skip_fields is None:
            skip_fields = set()
        # We need to skip the decode_wrapper field since it cannot be
        # broadcast with nccl when TP is enabled.
        skip_fields.add('decode_wrapper')
        return super().asdict_zerocopy(skip_fields)

    @property
    def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
        # Currently chunked prefill is not supported
        if self.num_decode_tokens == 0:
            assert self.num_prefills > 0
            return self
        return None

    @property
    def decode_metadata(self) -> Optional["FlashInferMetadata"]:
        # Currently chunked prefill is not supported
        if self.num_prefills > 0:
            assert self.num_decode_tokens == 0
            return None
        return self


class FlashInferImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        assert blocksparse_params is None, ValueError(
            "FlashInfer does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is not None:
            raise ValueError("Sliding window is not supported in FlashInfer.")
        self.sliding_window = (-1, -1)
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
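        # Grouped-query attention ratio. Example (illustrative numbers):
        # 32 query heads with 8 KV heads gives num_queries_per_kv = 4.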

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: FlashInferMetadata,
        kv_scale: float = 1.0,
    ) -> torch.Tensor:
        assert kv_scale == 1.0
        num_tokens, hidden_size = query.shape
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if attn_metadata.num_prefill_tokens > 0:
            assert attn_metadata.num_decode_tokens == 0, (
                "Chunked prefill is not supported with flashinfer yet.")
        if attn_metadata.num_decode_tokens > 0:
            assert attn_metadata.num_prefill_tokens == 0, (
                "Chunked prefill is not supported with flashinfer yet.")

        if kv_cache is not None:
            # Use the same reshape and cache kernel as flash attention.
            ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping.flatten(),
                self.kv_cache_dtype,
            )
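            # kv_cache[:, 0] and kv_cache[:, 1] are the key and value planes
            # of the (num_blocks, 2, block_size, num_kv_heads, head_size)
            # layout returned by get_kv_cache_shape.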

        if prefill_meta := attn_metadata.prefill_metadata:
            assert prefill_meta.block_tables is not None
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                output = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
            else:
                raise NotImplementedError(
                    "Prefix caching is not supported with flashinfer yet.")
        else:
            assert attn_metadata.decode_metadata is not None
            assert attn_metadata.decode_metadata.decode_wrapper is not None
            # FlashInfer requires the query to be contiguous.
            query = query.contiguous()
            output = attn_metadata.decode_metadata.decode_wrapper.forward(
                query,
                kv_cache,
                sm_scale=self.scale,
            )
        return output.view(num_tokens, hidden_size)
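

# Usage sketch (illustration only, not part of the backend). The head counts,
# sizes, and dtypes below are assumptions chosen for the example, not values
# taken from this file; the base AttentionMetadata fields are elided.
#
#   workspace = torch.empty(128 * 1024 * 1024,
#                           dtype=torch.uint8, device="cuda")
#   metadata = FlashInferBackend.make_metadata(
#       ...,                              # base AttentionMetadata fields
#       max_prefill_seq_len=0,
#       workspace_buffer=workspace,
#       paged_kv_indptr=indptr,           # shape [batch_size + 1]
#       paged_kv_indices=indices,         # concatenated page indices
#       paged_kv_last_page_len=last_lens, # shape [batch_size]
#       num_qo_heads=32, num_kv_heads=8, head_dim=128, page_size=16,
#       data_type=torch.float16)
#   impl = FlashInferBackend.get_impl_cls()(
#       num_heads=32, head_size=128, scale=128**-0.5, num_kv_heads=8,
#       alibi_slopes=None, sliding_window=None, kv_cache_dtype="auto")
#   output = impl.forward(query, key, value, kv_cache, metadata)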