# flash_attn.py

  1. """Attention layer with FlashAttention."""
  2. from dataclasses import dataclass
  3. from typing import List, Optional, Tuple, Type
  4. import torch
  5. from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
  6. from aphrodite._C import cache_ops
  7. from aphrodite.attention.backends.abstract import (AttentionBackend,
  8. AttentionImpl,
  9. AttentionMetadata)
  10. _SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256]


class FlashAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "flash-attn"

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionImpl"]:
        return FlashAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
        return FlashAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)
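
    # Example (illustrative numbers, not from the source): with 1024 blocks of
    # 16 tokens each, 8 KV heads, and head size 128, get_kv_cache_shape(1024,
    # 16, 8, 128) returns (2, 1024, 16, 8, 128). Index 0 of the leading
    # dimension is the key cache and index 1 is the value cache, which is how
    # swap_blocks below addresses them.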

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        src_key_cache = src_kv_cache[0]
        dst_key_cache = dst_kv_cache[0]
        cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)

        src_value_cache = src_kv_cache[1]
        dst_value_cache = dst_kv_cache[1]
        cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)


@dataclass
class FlashAttentionMetadata(AttentionMetadata):
    """Metadata for FlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored in a tensor. The tensor has to be
    updated from the `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding run.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE: Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|
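    # Example (illustrative numbers): if a sequence already has 12 computed
    # tokens and 4 new tokens are scheduled in this iteration, then
    # context_len = 12, query_len = 4, and seq_len = 16. For a decode step,
    # query_len is 1 and seq_len = context_len + 1.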
    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int]
    # Maximum sequence length among the prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among the decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery lengths
    # are [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence lengths
    # are [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,). A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]
    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical blocks)
    # E.g., [0, 1, 2] means tokens are stored in the 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # The 2nd dimension is padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]
    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.block_tables is not None
        assert self.seq_start_loc is not None

        self._cached_prefill_metadata = FlashAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
            block_tables=self.block_tables[:self.num_prefills],
            use_cuda_graph=False,
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        self._cached_decode_metadata = FlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[self.num_prefills:],
            use_cuda_graph=self.use_cuda_graph,
        )
        return self._cached_decode_metadata


class FlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.
    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if sliding_window is not None:
            # NOTE: flash-attn's sliding window does not work with
            # paged KV cache.
            raise ValueError(
                "Sliding window is not supported in FlashAttention.")
        if head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(
                f"Head size {head_size} is not supported by FlashAttention. "
                f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        kv_scale: float = 1.0,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape = [2, num_blocks, block_size, num_kv_heads,
                head_size]
            attn_metadata: Metadata for attention.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        # NOTE: FlashAttention does not support FP8 KV cache.
        assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention."

        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache = kv_cache[0]
            value_cache = kv_cache[1]

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            cache_ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping.flatten(),
                self.kv_cache_dtype,
            )

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens
        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if (kv_cache is None or prefill_meta.block_tables is None
                    or prefill_meta.block_tables.numel() == 0):
                # normal attention
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                out = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # prefix-enabled attention
                assert prefill_meta.seq_lens is not None
                max_seq_len = max(prefill_meta.seq_lens)
                output[:num_prefill_tokens] = flash_attn_varlen_func(
                    q=query,
                    k=key_cache,
                    v=value_cache,
                    cu_seqlens_q=prefill_meta.query_start_loc,
                    max_seqlen_q=prefill_meta.max_query_len,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_k=max_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    alibi_slopes=self.alibi_slopes,
                    block_table=prefill_meta.block_tables,
                )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            # flash_attn_with_kvcache expects q of shape
            # [batch_size, seqlen_q, num_heads, head_size]; decode tokens have
            # seqlen_q == 1, hence the unsqueeze/squeeze pair.
            output[num_prefill_tokens:] = flash_attn_with_kvcache(
                decode_query.unsqueeze(1),
                key_cache,
                value_cache,
                block_table=decode_meta.block_tables,
                cache_seqlens=decode_meta.seq_lens_tensor,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
            ).squeeze(1)

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
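

# The block below is an illustrative sketch, not part of the original module:
# it builds a FlashAttentionMetadata for a hypothetical mixed batch (one
# prefill of 4 tokens plus one decode token) and shows how the
# prefill_metadata/decode_metadata properties split it. All concrete numbers
# are assumptions chosen for demonstration only; no GPU kernel is invoked.
if __name__ == "__main__":
    meta = FlashAttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=4,
        num_decode_tokens=1,
        # One KV-cache slot per token (5 tokens total in this toy batch).
        slot_mapping=torch.arange(5),
        # Per-sequence lengths: the prefill sequence has 4 tokens; the decode
        # sequence has 7 cached tokens plus the new one.
        seq_lens=[4, 8],
        seq_lens_tensor=torch.tensor([4, 8]),
        max_query_len=4,
        max_prefill_seq_len=4,
        max_decode_seq_len=8,
        # Cumulative query lengths [4, 1] and sequence lengths [4, 8].
        query_start_loc=torch.tensor([0, 4, 5]),
        seq_start_loc=torch.tensor([0, 4, 12]),
        context_lens_tensor=torch.tensor([0, 7]),
        block_tables=torch.tensor([[0], [1]], dtype=torch.int32),
        use_cuda_graph=False,
    )
    # The prefill view keeps the first num_prefill_tokens slots and the first
    # num_prefills sequences; the decode view keeps the remainder.
    print(meta.prefill_metadata.seq_lens)        # [4]
    print(meta.decode_metadata.seq_lens_tensor)  # tensor([8])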