flash_attn.py

  1. """Attention layer with Flash and PagedAttention.
  2. NOTE: At the moment, this file includes a lot of duplicated code from
  3. XFormers backend. The duplicated code will be removed once we use flash-attn or
  4. flashinfer for all the attention operations.
  5. """
  6. from dataclasses import dataclass
  7. from typing import Dict, List, Optional, Tuple, Type
  8. import torch
  9. from flash_attn import flash_attn_varlen_func
  10. from aphrodite.attention.backends.abstract import (
  11. AttentionBackend,
  12. AttentionImpl,
  13. AttentionMetadata,
  14. AttentionMetadataPerStage,
  15. )
  16. from aphrodite.attention.ops.paged_attn import (
  17. PagedAttention,
  18. PagedAttentionMetadata,
  19. )
  20. class FlashAttentionBackend(AttentionBackend):
  21. @staticmethod
  22. def get_impl_cls() -> Type["FlashAttentionImpl"]:
  23. return FlashAttentionImpl
  24. @staticmethod
  25. def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
  26. return FlashAttentionMetadata(*args, **kwargs)
  27. @staticmethod
  28. def get_kv_cache_shape(
  29. num_blocks: int,
  30. block_size: int,
  31. num_kv_heads: int,
  32. head_size: int,
  33. ) -> Tuple[int, ...]:
  34. return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
  35. num_kv_heads, head_size)
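
    # Example (values chosen for illustration): num_blocks=512, block_size=16,
    # num_kv_heads=8, head_size=128 yields the paged cache shape documented in
    # `FlashAttentionImpl.forward`,
    # [2, num_blocks, block_size * num_kv_heads * head_size] = (2, 512, 16384).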

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class FlashAttentionMetadata(AttentionMetadataPerStage,
                             PagedAttentionMetadata):
    """Metadata for FlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored as tensors. The tensors have to be
    updated from the `CUDAGraphRunner.forward` API.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # (batch_size,). The prompt length per sequence. None if it is a decoding.
    prompt_lens: Optional[List[int]]
    # prompt_lens stored as a tensor.
    prompt_lens_tensor: Optional[torch.Tensor]

    # NOTE: Definition of context_len, subquery_len, and seqlen.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seqlen ----------------------|
    #                                   |- subquery_len -|
    # WARNING: context_len has a different definition depending on whether it
    # is prefill or decoding. When it is prefill, it doesn't include the new
    # tokens. When it is decoding, it includes the new token.
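    # Worked example (numbers chosen for illustration): a chunked-prefill step
    # that appends 4 new tokens to a sequence with 16 already-processed tokens
    # has context_len = 16, subquery_len = 4, and seqlen = 20. During decoding,
    # context_len already counts the single new token (see WARNING above).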

    # Maximum subquery length in the batch.
    max_subquery_len: Optional[int]
    # Maximum prompt length in the batch.
    max_prompt_len: Optional[int]
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    subquery_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda graph is currently enabled for decoding only.
    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool


class FlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.
    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
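        # NOTE: flash-attn interprets a window_size of (-1, -1) as an
        # unbounded (non-sliding) attention window.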
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
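        # E.g. (values chosen for illustration), num_heads=32 with
        # num_kv_heads=8 gives 4 query heads per KV head, as in
        # grouped-query attention.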

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata[FlashAttentionMetadata],
        kv_scale: float,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache,
                                                attn_metadata.slot_mapping,
                                                attn_metadata.kv_cache_dtype,
                                                kv_scale)

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]
        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # Normal attention.
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                out = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_prompt_len,
                    max_seqlen_k=prefill_meta.max_prompt_len,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # Prefix-enabled attention.
                # TODO: this triton kernel has a regression (it is broken)
                # when the KV dtype differs from the FP8 KV cache dtype;
                # to be addressed separately.
                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.subquery_start_loc,
                    prefill_meta.prompt_lens_tensor,
                    prefill_meta.context_lens,
                    prefill_meta.max_subquery_len,
                    self.alibi_slopes,
                )
        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.context_lens,
                decode_meta.max_context_len,
                attn_metadata.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
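

# Minimal illustrative sketch: how the backend pieces fit together, assuming
# flash-attn and the aphrodite packages imported above are installed. The
# head/block values below are chosen for the example only; constructing a full
# AttentionMetadata and calling `forward` is omitted because that metadata is
# built elsewhere in the engine.
if __name__ == "__main__":
    impl_cls = FlashAttentionBackend.get_impl_cls()
    impl = impl_cls(num_heads=32, head_size=128, scale=128 ** -0.5,
                    num_kv_heads=8)
    kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
        num_blocks=512, block_size=16, num_kv_heads=8, head_size=128)
    # Allocate a (CPU) KV cache buffer with the backend-provided shape.
    kv_cache = torch.zeros(kv_cache_shape, dtype=torch.float16)
    print(f"impl: {type(impl).__name__}, kv_cache shape: {tuple(kv_cache.shape)}")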