  1. """Attention layer with FlashAttention."""
  2. from dataclasses import dataclass
  3. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
  4. import torch
  5. from aphrodite import _custom_ops as ops
  6. from aphrodite.attention.backends.abstract import (AttentionBackend,
  7. AttentionImpl,
  8. AttentionMetadata,
  9. AttentionMetadataBuilder,
  10. AttentionType)
  11. from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
  12. CommonAttentionState,
  13. compute_slot_mapping,
  14. compute_slot_mapping_start_idx,
  15. is_block_tables_empty)
  16. from aphrodite.common.utils import async_tensor_h2d, make_tensor_with_pad
  17. if TYPE_CHECKING:
  18. from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder
  19. from aphrodite_flash_attn import (
  20. flash_attn_varlen_func as _flash_attn_varlen_func)
  21. from aphrodite_flash_attn import (
  22. flash_attn_with_kvcache as _flash_attn_with_kvcache)
@torch.library.custom_op("aphrodite::flash_attn_varlen_func", mutates_args=[])
def flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Optional[List[int]] = None,
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # custom op does not support tuple input
    real_window_size: Tuple[int, int]
    if window_size is None:
        real_window_size = (-1, -1)
    else:
        assert len(window_size) == 2
        real_window_size = (window_size[0], window_size[1])
    return _flash_attn_varlen_func(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=real_window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        block_table=block_table,
    )


@flash_attn_varlen_func.register_fake  # type: ignore
def _(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Optional[List[int]] = None,
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return torch.empty_like(q)


@torch.library.custom_op("aphrodite::flash_attn_with_kvcache", mutates_args=[])
def flash_attn_with_kvcache(
    decode_query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cache_seqlens: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    alibi_slopes: Optional[torch.Tensor] = None,
    softcap: float = 0.0,
) -> torch.Tensor:
    return _flash_attn_with_kvcache(
        decode_query,
        key_cache,
        value_cache,
        cache_seqlens=cache_seqlens,
        block_table=block_table,
        softmax_scale=softmax_scale,
        causal=causal,
        alibi_slopes=alibi_slopes,
        softcap=softcap,
    )


@flash_attn_with_kvcache.register_fake  # type: ignore
def _(
    decode_query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cache_seqlens: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    alibi_slopes: Optional[torch.Tensor] = None,
    softcap: float = 0.0,
) -> torch.Tensor:
    return torch.empty_like(decode_query)
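
# NOTE: Illustrative sketch only. The wrappers above re-register the external
# aphrodite_flash_attn kernels as torch custom ops with fake (meta)
# implementations, so tracing machinery such as torch.compile can handle them.
# Assuming CUDA half-precision tensors shaped as used later in this file, a
# direct call looks like:
#
#   out = torch.ops.aphrodite.flash_attn_varlen_func(
#       q=q, k=k, v=v,
#       cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
#       max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
#       softmax_scale=scale, causal=True)
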
class FlashAttentionBackend(AttentionBackend):

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        return "flash-attn"

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionImpl"]:
        return FlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return FlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
        return FlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)
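
    # Illustrative example (assumed numbers, not defaults): num_blocks=1024,
    # block_size=16, num_kv_heads=8, head_size=128 gives a KV cache of shape
    # (2, 1024, 16, 8, 128), where index 0 holds keys and index 1 holds
    # values.
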
    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        src_key_cache = src_kv_cache[0]
        dst_key_cache = dst_kv_cache[0]
        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
        src_value_cache = src_kv_cache[1]
        dst_value_cache = dst_kv_cache[1]
        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)


@dataclass
class FlashAttentionMetadata(AttentionMetadata):
    """Metadata for FlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the number of computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE: Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|
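    # Example: if 12 tokens of a prompt were computed in an earlier chunk and
    # 4 new tokens are scheduled now, then context_len = 12, query_len = 4,
    # and seq_len = context_len + query_len = 16.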
    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int]
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,). A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.block_tables is not None
        assert self.seq_start_loc is not None

        self._cached_prefill_metadata = FlashAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
            block_tables=self.block_tables[:self.num_prefills],
            use_cuda_graph=False,
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        self._cached_decode_metadata = FlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[self.num_prefills:],
            use_cuda_graph=self.use_cuda_graph,
        )
        return self._cached_decode_metadata

    def advance_step(self, num_seqs: int, num_queries: int):
        """
        Update metadata in-place to advance one decode step.
        """
        # GPU in-place update is currently called separately through
        # custom_ops.advance_step(). See draft_model_runner.
        # TODO: Move this logic to the backend.

        # When using cudagraph, num_seqs is padded to the next captured
        # batch size, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries.
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0
        assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update sequence lengths for the actual requests only (not the padded
        # entries), since tensors may be padded to the captured cuda graph
        # batch size.
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)
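        # Illustrative example (assumed values): with num_queries == 2 and
        # seq_lens starting as [5, 7, ...], the loop above advances only the
        # first two entries to [6, 8, ...]; padded entries beyond num_queries
        # are left untouched.
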
class FlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[FlashAttentionMetadata]):

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0
        self.has_prefix_cache_hit = False

        self.input_builder = input_builder
        self.runner = input_builder.runner
        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size
        self.use_v2_block_manager = (
            input_builder.scheduler_config.use_v2_block_manager)

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool, prefix_cache_hit: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)

            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO: Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if prefix_cache_hit:
                # NOTE: For flash-attn, the block table should
                # include the entries for the incoming prefill tokens.
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(
                is_prompt, query_len, context_len, self.sliding_window,
                self.use_v2_block_manager)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        prefix_cache_hit = any([
            inter_data.prefix_cache_hit
            for inter_data in self.input_builder.inter_data_list
        ])
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled,
                                prefix_cache_hit)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            self.block_tables.extend([] * cuda_graph_pad_size)
            num_decode_tokens = batch_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.runner.graph_block_tables[:batch_size]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.from_numpy(input_block_tables).to(
                device=device, non_blocking=True)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
                                             self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])
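        # Example: query_lens == [4, 6] yields query_start_loc == [0, 4, 10],
        # and likewise for seq_start_loc from seq_lens. Writing the cumsum
        # into the [1:] slice keeps the leading zero in place.
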
        return FlashAttentionMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )


class FlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.
    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        if blocksparse_params is not None:
            raise ValueError(
                "FlashAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if sliding_window is not None:
            # NOTE: flash-attn's sliding window does not work with
            # paged KV cache.
            raise ValueError(
                "Sliding window is not supported in FlashAttention.")

        support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
        if head_size not in support_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by FlashAttention. "
                f"Supported head sizes are: {support_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttentionImpl")

        # NOTE: FlashAttention does not support FP8 KV cache.
        assert k_scale == 1.0 and v_scale == 1.0, (
            "key/v_scale is not supported in FlashAttention.")

        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache = kv_cache[0]
            value_cache = kv_cache[1]

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping.flatten(),
                self.kv_cache_dtype,
                k_scale,
                v_scale,
            )

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]
        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens
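        # E.g. with num_prefill_tokens == 10 and num_decode_tokens == 3, rows
        # [0, 10) of query/key/value belong to the prefill requests and rows
        # [10, 13) of query hold the single new token of each decode request.
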
        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if (kv_cache is None or prefill_meta.block_tables is None
                    or prefill_meta.block_tables.numel() == 0):
                # normal attention
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                out = torch.ops.aphrodite.flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                    softcap=self.logits_soft_cap,
                )
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # prefix-enabled attention
                assert prefill_meta.seq_lens is not None
                max_seq_len = max(prefill_meta.seq_lens)
                output[:num_prefill_tokens] = torch.ops.aphrodite.flash_attn_varlen_func(  # noqa
                    q=query,
                    k=key_cache,
                    v=value_cache,
                    cu_seqlens_q=prefill_meta.query_start_loc,
                    max_seqlen_q=prefill_meta.max_query_len,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_k=max_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    alibi_slopes=self.alibi_slopes,
                    block_table=prefill_meta.block_tables,
                    softcap=self.logits_soft_cap,
                )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            output[num_prefill_tokens:] = torch.ops.aphrodite.flash_attn_with_kvcache(  # noqa
                decode_query.unsqueeze(1),
                key_cache,
                value_cache,
                block_table=decode_meta.block_tables,
                cache_seqlens=decode_meta.seq_lens_tensor,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                softcap=self.logits_soft_cap,
            ).squeeze(1)

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)