flash_attn.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. """Attention layer with FlashAttention."""
  2. from dataclasses import dataclass
  3. from typing import Any, Dict, List, Optional, Tuple, Type
  4. import torch
  5. from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
  6. from aphrodite._C import cache_ops
  7. from aphrodite.attention.backends.abstract import (AttentionBackend,
  8. AttentionImpl,
  9. AttentionMetadata)
  10. class FlashAttentionBackend(AttentionBackend):
  11. @staticmethod
  12. def get_supported_head_sizes() -> List[int]:
  13. return [32, 64, 96, 128, 160, 192, 224, 256]
  14. @staticmethod
  15. def get_name() -> str:
  16. return "flash-attn"
  17. @staticmethod
  18. def get_impl_cls() -> Type["FlashAttentionImpl"]:
  19. return FlashAttentionImpl
  20. @staticmethod
  21. def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
  22. return FlashAttentionMetadata(*args, **kwargs)
  23. @staticmethod
  24. def get_kv_cache_shape(
  25. num_blocks: int,
  26. block_size: int,
  27. num_kv_heads: int,
  28. head_size: int,
  29. ) -> Tuple[int, ...]:
  30. if block_size % 16 != 0:
  31. raise ValueError("Block size must be a multiple of 16.")
  32. return (2, num_blocks, block_size, num_kv_heads, head_size)
  33. @staticmethod
  34. def swap_blocks(
  35. src_kv_cache: torch.Tensor,
  36. dst_kv_cache: torch.Tensor,
  37. src_to_dst: torch.Tensor,
  38. ) -> None:
  39. src_key_cache = src_kv_cache[0]
  40. dst_key_cache = dst_kv_cache[0]
  41. cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
  42. src_value_cache = src_kv_cache[1]
  43. dst_value_cache = dst_kv_cache[1]
  44. cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
  45. @staticmethod
  46. def copy_blocks(
  47. kv_caches: List[torch.Tensor],
  48. src_to_dists: torch.Tensor,
  49. ) -> None:
  50. key_caches = [kv_cache[0] for kv_cache in kv_caches]
  51. value_caches = [kv_cache[1] for kv_cache in kv_caches]
  52. cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
  53. @dataclass
  54. class FlashAttentionMetadata(AttentionMetadata):
  55. """Metadata for FlashAttentionBackend.
  56. NOTE: Any python object stored here is not updated when it is
  57. cuda-graph replayed. If you have values that need to be changed
  58. dynamically, it should be stored in tensor. The tensor has to be
  59. updated from `CUDAGraphRunner.forward` API.
  60. """
  61. # (batch_size,). The sequence length per sequence. Sequence length means
  62. # the computed tokens + new tokens None if it is a decoding.
  63. seq_lens: Optional[List[int]]
  64. # seq_lens stored as a tensor.
  65. seq_lens_tensor: Optional[torch.Tensor]
  66. # NOTE: Definition of context_len, query_len, and seq_len.
  67. # |---------- N-1 iteration --------|
  68. # |---------------- N iteration ---------------------|
  69. # |- tokenA -|......................|-- newTokens ---|
  70. # |---------- context_len ----------|
  71. # |-------------------- seq_len ----------------------|
  72. # |-- query_len ---|
  73. # Maximum query length in the batch. None for decoding.
  74. max_query_len: Optional[int]
  75. # Maximum sequence length among prefill batch. 0 if there are decoding
  76. # requests only.
  77. max_prefill_seq_len: int
  78. # Maximum sequence length among decode batch. 0 if there are prefill
  79. # requests only.
  80. max_decode_seq_len: int
  81. # (batch_size + 1,). The cumulative subquery lengths of the sequences in
  82. # the batch, used to index into subquery. E.g., if the subquery length
  83. # is [4, 6], it is [0, 4, 10].
  84. query_start_loc: Optional[torch.Tensor]
  85. # (batch_size + 1,). The cumulative sequence lengths of the sequences in
  86. # the batch, used to index into sequence. E.g., if the sequence length is
  87. # [4, 6], it is [0, 4, 10].
  88. seq_start_loc: Optional[torch.Tensor]
  89. # (batch_size,) A tensor of context lengths (tokens that are computed
  90. # so far).
  91. context_lens_tensor: Optional[torch.Tensor]
  92. # (batch_size, max_blocks_per_seq).
  93. # Block addresses per sequence. (Seq id -> list of physical block)
  94. # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
  95. # in the kv cache. Each block can contain up to block_size tokens.
  96. # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
  97. # captured.
  98. block_tables: Optional[torch.Tensor]
  99. # Whether or not if cuda graph is enabled.
  100. # Cuda-graph is currently enabled for decoding only.
  101. # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
  102. use_cuda_graph: bool
  103. _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
  104. _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None
  105. @property
  106. def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
  107. if self.num_prefills == 0:
  108. return None
  109. if self._cached_prefill_metadata is not None:
  110. return self._cached_prefill_metadata
  111. assert self.seq_lens is not None
  112. assert self.seq_lens_tensor is not None
  113. assert self.query_start_loc is not None
  114. assert self.context_lens_tensor is not None
  115. assert self.block_tables is not None
  116. assert self.seq_start_loc is not None
  117. self._cached_prefill_metadata = FlashAttentionMetadata(
  118. num_prefills=self.num_prefills,
  119. num_prefill_tokens=self.num_prefill_tokens,
  120. num_decode_tokens=0,
  121. slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
  122. seq_lens=self.seq_lens[:self.num_prefills],
  123. seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
  124. max_query_len=self.max_query_len,
  125. max_prefill_seq_len=self.max_prefill_seq_len,
  126. max_decode_seq_len=0,
  127. query_start_loc=self.query_start_loc[:self.num_prefills + 1],
  128. seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
  129. context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
  130. block_tables=self.block_tables[:self.num_prefills],
  131. use_cuda_graph=False,
  132. )
  133. return self._cached_prefill_metadata
  134. @property
  135. def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
  136. if self.num_decode_tokens == 0:
  137. return None
  138. if self._cached_decode_metadata is not None:
  139. return self._cached_decode_metadata
  140. assert self.block_tables is not None
  141. assert self.seq_lens_tensor is not None
  142. self._cached_decode_metadata = FlashAttentionMetadata(
  143. num_prefills=0,
  144. num_prefill_tokens=0,
  145. num_decode_tokens=self.num_decode_tokens,
  146. slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
  147. seq_lens=None,
  148. seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
  149. max_query_len=None,
  150. max_prefill_seq_len=0,
  151. max_decode_seq_len=self.max_decode_seq_len,
  152. query_start_loc=None,
  153. seq_start_loc=None,
  154. context_lens_tensor=None,
  155. block_tables=self.block_tables[self.num_prefills:],
  156. use_cuda_graph=self.use_cuda_graph,
  157. )
  158. return self._cached_decode_metadata
  159. class FlashAttentionImpl(AttentionImpl):
  160. """
  161. If the input tensors contain prompt tokens, the layout is as follows:
  162. |<--------------- num_prefill_tokens ----------------->|
  163. |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
  164. Otherwise, the layout is as follows:
  165. |<----------------- num_decode_tokens ------------------>|
  166. |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
  167. Generation tokens can contain padding when cuda-graph is used.
  168. Currently, prompt tokens don't contain any padding.
  169. The prompts might have different lengths, while the generation tokens
  170. always have length 1.
  171. If chunked prefill is enabled, prefill tokens and decode tokens can be
  172. batched together in a flattened 1D query.
  173. |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
  174. |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
  175. Currently, cuda graph is disabled for chunked prefill, meaning there's no
  176. padding between prefill and decode tokens.
  177. """
  178. def __init__(
  179. self,
  180. num_heads: int,
  181. head_size: int,
  182. scale: float,
  183. num_kv_heads: int,
  184. alibi_slopes: Optional[List[float]],
  185. sliding_window: Optional[int],
  186. kv_cache_dtype: str,
  187. blocksparse_params: Optional[Dict[str, Any]] = None,
  188. ) -> None:
  189. assert blocksparse_params is None, ValueError(
  190. "FlashAttention does not support block-sparse attention.")
  191. self.num_heads = num_heads
  192. self.head_size = head_size
  193. self.scale = float(scale)
  194. self.num_kv_heads = num_kv_heads
  195. if alibi_slopes is not None:
  196. alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
  197. self.alibi_slopes = alibi_slopes
  198. self.sliding_window = ((sliding_window, sliding_window)
  199. if sliding_window is not None else (-1, -1))
  200. self.kv_cache_dtype = kv_cache_dtype
  201. assert self.num_heads % self.num_kv_heads == 0
  202. self.num_queries_per_kv = self.num_heads // self.num_kv_heads
  203. if sliding_window is not None:
  204. # NOTE: flash-attn's sliding window does not work with
  205. # paged KV cache.
  206. raise ValueError(
  207. "Sliding window is not supported in FlashAttention.")
  208. support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
  209. if head_size not in support_head_sizes:
  210. raise ValueError(
  211. f"Head size {head_size} is not supported by FlashAttention. "
  212. f"Supported head sizes are: {support_head_sizes}.")
  213. def forward(
  214. self,
  215. query: torch.Tensor,
  216. key: torch.Tensor,
  217. value: torch.Tensor,
  218. kv_cache: torch.Tensor,
  219. attn_metadata: FlashAttentionMetadata,
  220. kv_scale: float = 1.0,
  221. ) -> torch.Tensor:
  222. """Forward pass with FlashAttention.
  223. Args:
  224. query: shape = [num_tokens, num_heads * head_size]
  225. key: shape = [num_tokens, num_kv_heads * head_size]
  226. value: shape = [num_tokens, num_kv_heads * head_size]
  227. kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
  228. attn_metadata: Metadata for attention.
  229. Returns:
  230. shape = [num_tokens, num_heads * head_size]
  231. """
  232. # NOTE: FlashAttention does not support FP8 KV cache.
  233. assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention."
  234. num_tokens, hidden_size = query.shape
  235. # Reshape the query, key, and value tensors.
  236. query = query.view(-1, self.num_heads, self.head_size)
  237. key = key.view(-1, self.num_kv_heads, self.head_size)
  238. value = value.view(-1, self.num_kv_heads, self.head_size)
  239. if kv_cache is not None:
  240. key_cache = kv_cache[0]
  241. value_cache = kv_cache[1]
  242. # Reshape the input keys and values and store them in the cache.
  243. # If kv_cache is not provided, the new key and value tensors are
  244. # not cached. This happens during the initial memory profiling run.
  245. cache_ops.reshape_and_cache_flash(
  246. key,
  247. value,
  248. key_cache,
  249. value_cache,
  250. attn_metadata.slot_mapping.flatten(),
  251. self.kv_cache_dtype,
  252. )
  253. num_prefill_tokens = attn_metadata.num_prefill_tokens
  254. num_decode_tokens = attn_metadata.num_decode_tokens
  255. assert key.shape[0] == num_prefill_tokens + num_decode_tokens
  256. assert value.shape[0] == num_prefill_tokens + num_decode_tokens
  257. output = torch.empty_like(query)
  258. # Query for decode. KV is not needed because it is already cached.
  259. decode_query = query[num_prefill_tokens:]
  260. # QKV for prefill.
  261. query = query[:num_prefill_tokens]
  262. key = key[:num_prefill_tokens]
  263. value = value[:num_prefill_tokens]
  264. assert query.shape[0] == num_prefill_tokens
  265. assert decode_query.shape[0] == num_decode_tokens
  266. if prefill_meta := attn_metadata.prefill_metadata:
  267. # Prompt run.
  268. if (kv_cache is None or prefill_meta.block_tables is None
  269. or prefill_meta.block_tables.numel() == 0):
  270. # normal attention
  271. # When block_tables are not filled, it means q and k are the
  272. # prompt, and they have the same length.
  273. out = flash_attn_varlen_func(
  274. q=query,
  275. k=key,
  276. v=value,
  277. cu_seqlens_q=prefill_meta.seq_start_loc,
  278. cu_seqlens_k=prefill_meta.seq_start_loc,
  279. max_seqlen_q=prefill_meta.max_prefill_seq_len,
  280. max_seqlen_k=prefill_meta.max_prefill_seq_len,
  281. softmax_scale=self.scale,
  282. causal=True,
  283. window_size=self.sliding_window,
  284. alibi_slopes=self.alibi_slopes,
  285. )
  286. assert output[:num_prefill_tokens].shape == out.shape
  287. output[:num_prefill_tokens] = out
  288. else:
  289. # prefix-enabled attention
  290. assert prefill_meta.seq_lens is not None
  291. max_seq_len = max(prefill_meta.seq_lens)
  292. output[:num_prefill_tokens] = flash_attn_varlen_func(
  293. q=query,
  294. k=key_cache,
  295. v=value_cache,
  296. cu_seqlens_q=prefill_meta.query_start_loc,
  297. max_seqlen_q=prefill_meta.max_query_len,
  298. cu_seqlens_k=prefill_meta.seq_start_loc,
  299. max_seqlen_k=max_seq_len,
  300. softmax_scale=self.scale,
  301. causal=True,
  302. alibi_slopes=self.alibi_slopes,
  303. block_table=prefill_meta.block_tables,
  304. )
  305. if decode_meta := attn_metadata.decode_metadata:
  306. # Decoding run.
  307. output[num_prefill_tokens:] = flash_attn_with_kvcache(
  308. decode_query.unsqueeze(1),
  309. key_cache,
  310. value_cache,
  311. block_table=decode_meta.block_tables,
  312. cache_seqlens=decode_meta.seq_lens_tensor,
  313. softmax_scale=self.scale,
  314. causal=True,
  315. alibi_slopes=self.alibi_slopes,
  316. ).squeeze(1)
  317. # Reshape the output tensor.
  318. return output.view(num_tokens, hidden_size)