flash_attn.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. """Attention layer with Flash and PagedAttention.
  2. NOTE: At the moment, this file includes a lot of duplicated code from
  3. XFormers backend. The duplicated code will be removed once we use flash-attn or
  4. flashinfer for all the attention operations.
  5. """
  6. from dataclasses import dataclass
  7. from typing import Dict, List, Optional, Tuple, Type
  8. import torch
  9. from flash_attn import flash_attn_varlen_func
  10. from aphrodite.attention.backends.abstract import (AttentionBackend,
  11. AttentionImpl,
  12. AttentionMetadata,
  13. AttentionMetadataPerStage)
  14. from aphrodite.attention.ops.paged_attn import (PagedAttention,
  15. PagedAttentionMetadata)
  16. class FlashAttentionBackend(AttentionBackend):
  17. @staticmethod
  18. def get_impl_cls() -> Type["FlashAttentionImpl"]:
  19. return FlashAttentionImpl
  20. @staticmethod
  21. def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
  22. return FlashAttentionMetadata(*args, **kwargs)
  23. @staticmethod
  24. def get_kv_cache_shape(
  25. num_blocks: int,
  26. block_size: int,
  27. num_kv_heads: int,
  28. head_size: int,
  29. ) -> Tuple[int, ...]:
  30. return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
  31. num_kv_heads, head_size)
  32. @staticmethod
  33. def swap_blocks(
  34. src_kv_cache: torch.Tensor,
  35. dst_kv_cache: torch.Tensor,
  36. src_to_dst: Dict[int, int],
  37. ) -> None:
  38. PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
  39. @staticmethod
  40. def copy_blocks(
  41. kv_caches: List[torch.Tensor],
  42. src_to_dists: Dict[int, List[int]],
  43. ) -> None:
  44. PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class FlashAttentionMetadata(AttentionMetadataPerStage,
                             PagedAttentionMetadata):
    """Metadata for FlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored in a tensor. The tensor has to be
    updated from the `CUDAGraphRunner.forward` API.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # (batch_size,). The prompt length per sequence. None if it is decoding.
    prompt_lens: Optional[List[int]]
    # prompt_lens stored as a tensor.
    prompt_lens_tensor: Optional[torch.Tensor]

    # NOTE: Definition of context_len, subquery_len, and seqlen.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seqlen ----------------------|
    #                                   |- subquery_len -|

    # WARNING: context_len has a different definition depending on whether it
    # is prefill or decoding. For prefill, it doesn't include the new tokens.
    # For decoding, it includes the new token.

    # Maximum subquery length in the batch.
    max_subquery_len: Optional[int]
    # Maximum prompt length in the batch.
    max_prompt_len: Optional[int]
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    subquery_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool
  87. class FlashAttentionImpl(AttentionImpl):
  88. """
  89. If the input tensors contain prompt tokens, the layout is as follows:
  90. |<--------------- num_prefill_tokens ----------------->|
  91. |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
  92. Otherwise, the layout is as follows:
  93. |<----------------- num_decode_tokens ------------------>|
  94. |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
  95. Generation tokens can contain padding when cuda-graph is used.
  96. Currently, prompt tokens don't contain any padding.
  97. The prompts might have different lengths, while the generation tokens
  98. always have length 1.
  99. If chunked prefill is enabled, prefill tokens and decode tokens can be
  100. batched together in a flattened 1D query.
  101. |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
  102. |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
  103. Currently, cuda graph is disabled for chunked prefill, meaning there's no
  104. padding between prefill and decode tokens.
  105. """
  106. def __init__(
  107. self,
  108. num_heads: int,
  109. head_size: int,
  110. scale: float,
  111. num_kv_heads: Optional[int] = None,
  112. alibi_slopes: Optional[List[float]] = None,
  113. sliding_window: Optional[int] = None,
  114. ) -> None:
  115. self.num_heads = num_heads
  116. self.head_size = head_size
  117. self.scale = float(scale)
  118. self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
  119. self.sliding_window = ((sliding_window, sliding_window)
  120. if sliding_window is not None else (-1, -1))
  121. if alibi_slopes is not None:
  122. alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
  123. self.alibi_slopes = alibi_slopes
  124. assert self.num_heads % self.num_kv_heads == 0
  125. self.num_queries_per_kv = self.num_heads // self.num_kv_heads
  126. suppored_head_sizes = PagedAttention.get_supported_head_sizes()
  127. if head_size not in suppored_head_sizes:
  128. raise ValueError(
  129. f"Head size {head_size} is not supported by PagedAttention. "
  130. f"Supported head sizes are: {suppored_head_sizes}.")
  131. def forward(
  132. self,
  133. query: torch.Tensor,
  134. key: torch.Tensor,
  135. value: torch.Tensor,
  136. kv_cache: torch.Tensor,
  137. attn_metadata: AttentionMetadata[FlashAttentionMetadata],
  138. kv_scale: float,
  139. ) -> torch.Tensor:
  140. """Forward pass with FlashAttention and PagedAttention.
  141. Args:
  142. query: shape = [num_tokens, num_heads * head_size]
  143. key: shape = [num_tokens, num_kv_heads * head_size]
  144. value: shape = [num_tokens, num_kv_heads * head_size]
  145. kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
  146. attn_metadata: Metadata for attention.
  147. Returns:
  148. shape = [num_tokens, num_heads * head_size]
  149. """
  150. num_tokens, hidden_size = query.shape
  151. # Reshape the query, key, and value tensors.
  152. query = query.view(-1, self.num_heads, self.head_size)
  153. key = key.view(-1, self.num_kv_heads, self.head_size)
  154. value = value.view(-1, self.num_kv_heads, self.head_size)
  155. if kv_cache is not None:
  156. key_cache, value_cache = PagedAttention.split_kv_cache(
  157. kv_cache, self.num_kv_heads, self.head_size)
  158. # Reshape the input keys and values and store them in the cache.
  159. # If kv_cache is not provided, the new key and value tensors are
  160. # not cached. This happens during the initial memory profiling run.
  161. PagedAttention.write_to_paged_cache(key, value, key_cache,
  162. value_cache,
  163. attn_metadata.slot_mapping,
  164. attn_metadata.kv_cache_dtype,
  165. kv_scale)
  166. num_prefill_tokens = attn_metadata.num_prefill_tokens
  167. num_decode_tokens = attn_metadata.num_decode_tokens
  168. assert key.shape[0] == num_prefill_tokens + num_decode_tokens
  169. assert value.shape[0] == num_prefill_tokens + num_decode_tokens
  170. output = torch.empty_like(query)
  171. # Query for decode. KV is not needed because it is already cached.
  172. decode_query = query[num_prefill_tokens:]
  173. # QKV for prefill.
  174. query = query[:num_prefill_tokens]
  175. key = key[:num_prefill_tokens]
  176. value = value[:num_prefill_tokens]
  177. assert query.shape[0] == num_prefill_tokens
  178. assert decode_query.shape[0] == num_decode_tokens
  179. if prefill_meta := attn_metadata.prefill_metadata:
  180. # Prompt run.
  181. if kv_cache is None or prefill_meta.block_tables.numel() == 0:
  182. # normal attention
  183. # When block_tables are not filled, it means q and k are the
  184. # prompt, and they have the same length.
  185. out = flash_attn_varlen_func(
  186. q=query,
  187. k=key,
  188. v=value,
  189. cu_seqlens_q=prefill_meta.seq_start_loc,
  190. cu_seqlens_k=prefill_meta.seq_start_loc,
  191. max_seqlen_q=prefill_meta.max_prompt_len,
  192. max_seqlen_k=prefill_meta.max_prompt_len,
  193. softmax_scale=self.scale,
  194. causal=True,
  195. window_size=self.sliding_window,
  196. alibi_slopes=self.alibi_slopes,
  197. )
  198. assert output[:num_prefill_tokens].shape == out.shape
  199. output[:num_prefill_tokens] = out
  200. else:
  201. # prefix-enabled attention
  202. # TODO: this triton kernel has regression issue (broke) to
  203. # deal with different data types between KV and FP8 KV cache,
  204. # to be addressed separately.
  205. output[:num_prefill_tokens] = PagedAttention.forward_prefix(
  206. query,
  207. key,
  208. value,
  209. key_cache,
  210. value_cache,
  211. prefill_meta.block_tables,
  212. prefill_meta.subquery_start_loc,
  213. prefill_meta.prompt_lens_tensor,
  214. prefill_meta.context_lens,
  215. prefill_meta.max_subquery_len,
  216. self.alibi_slopes,
  217. )
  218. if decode_meta := attn_metadata.decode_metadata:
  219. # Decoding run.
  220. output[num_prefill_tokens:] = PagedAttention.forward_decode(
  221. decode_query,
  222. key_cache,
  223. value_cache,
  224. decode_meta.block_tables,
  225. decode_meta.context_lens,
  226. decode_meta.max_context_len,
  227. attn_metadata.kv_cache_dtype,
  228. self.num_kv_heads,
  229. self.scale,
  230. self.alibi_slopes,
  231. kv_scale,
  232. )
  233. # Reshape the output tensor.
  234. return output.view(num_tokens, hidden_size)