flash_attn.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779
  1. """Attention layer with FlashAttention."""
  2. from dataclasses import dataclass
  3. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
  4. import torch
  5. from aphrodite import _custom_ops as ops
  6. from aphrodite.attention.backends.abstract import (AttentionBackend,
  7. AttentionImpl,
  8. AttentionMetadata,
  9. AttentionMetadataBuilder,
  10. AttentionType)
  11. from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
  12. CommonAttentionState,
  13. compute_slot_mapping,
  14. compute_slot_mapping_start_idx,
  15. is_block_tables_empty)
  16. from aphrodite.common.utils import async_tensor_h2d, make_tensor_with_pad
  17. if TYPE_CHECKING:
  18. from aphrodite.worker.model_runner import (ModelInputForGPUBuilder,
  19. ModelInputForGPUWithSamplingMetadata)
  20. from aphrodite_flash_attn import (
  21. flash_attn_varlen_func as _flash_attn_varlen_func)
  22. from aphrodite_flash_attn import (
  23. flash_attn_with_kvcache as _flash_attn_with_kvcache)
  24. @torch.library.custom_op("aphrodite::flash_attn_varlen_func", mutates_args=[])
  25. def flash_attn_varlen_func(
  26. q: torch.Tensor,
  27. k: torch.Tensor,
  28. v: torch.Tensor,
  29. cu_seqlens_q: torch.Tensor,
  30. cu_seqlens_k: torch.Tensor,
  31. max_seqlen_q: int,
  32. max_seqlen_k: int,
  33. softmax_scale: Optional[float] = None,
  34. causal: bool = False,
  35. window_size: Optional[List[int]] = None,
  36. softcap: float = 0.0,
  37. alibi_slopes: Optional[torch.Tensor] = None,
  38. block_table: Optional[torch.Tensor] = None,
  39. ) -> torch.Tensor:
  40. # custom op does not support tuple input
  41. real_window_size: Tuple[int, int]
  42. if window_size is None:
  43. real_window_size = (-1, -1)
  44. else:
  45. assert len(window_size) == 2
  46. real_window_size = (window_size[0], window_size[1])
  47. return _flash_attn_varlen_func(
  48. q=q,
  49. k=k,
  50. v=v,
  51. cu_seqlens_q=cu_seqlens_q,
  52. cu_seqlens_k=cu_seqlens_k,
  53. max_seqlen_q=max_seqlen_q,
  54. max_seqlen_k=max_seqlen_k,
  55. softmax_scale=softmax_scale,
  56. causal=causal,
  57. window_size=real_window_size,
  58. softcap=softcap,
  59. alibi_slopes=alibi_slopes,
  60. block_table=block_table,
  61. )
  62. @flash_attn_varlen_func.register_fake # type: ignore
  63. def _(
  64. q: torch.Tensor,
  65. k: torch.Tensor,
  66. v: torch.Tensor,
  67. cu_seqlens_q: torch.Tensor,
  68. cu_seqlens_k: torch.Tensor,
  69. max_seqlen_q: int,
  70. max_seqlen_k: int,
  71. softmax_scale: Optional[float] = None,
  72. causal: bool = False,
  73. window_size: Optional[List[int]] = None,
  74. softcap: float = 0.0,
  75. alibi_slopes: Optional[torch.Tensor] = None,
  76. block_table: Optional[torch.Tensor] = None,
  77. ) -> torch.Tensor:
  78. return torch.empty_like(q)
  79. @torch.library.custom_op("aphrodite::flash_attn_with_kvcache", mutates_args=[])
  80. def flash_attn_with_kvcache(
  81. decode_query: torch.Tensor,
  82. key_cache: torch.Tensor,
  83. value_cache: torch.Tensor,
  84. cache_seqlens: Optional[torch.Tensor] = None,
  85. block_table: Optional[torch.Tensor] = None,
  86. softmax_scale: Optional[float] = None,
  87. causal: bool = False,
  88. alibi_slopes: Optional[torch.Tensor] = None,
  89. softcap: float = 0.0,
  90. ) -> torch.Tensor:
  91. return _flash_attn_with_kvcache(
  92. decode_query,
  93. key_cache,
  94. value_cache,
  95. cache_seqlens=cache_seqlens,
  96. block_table=block_table,
  97. softmax_scale=softmax_scale,
  98. causal=causal,
  99. alibi_slopes=alibi_slopes,
  100. softcap=softcap,
  101. )
  102. @flash_attn_with_kvcache.register_fake # type: ignore
  103. def _(
  104. decode_query: torch.Tensor,
  105. key_cache: torch.Tensor,
  106. value_cache: torch.Tensor,
  107. cache_seqlens: Optional[torch.Tensor] = None,
  108. block_table: Optional[torch.Tensor] = None,
  109. softmax_scale: Optional[float] = None,
  110. causal: bool = False,
  111. alibi_slopes: Optional[torch.Tensor] = None,
  112. softcap: float = 0.0,
  113. ) -> torch.Tensor:
  114. return torch.empty_like(decode_query)
  115. @torch.library.custom_op("aphrodite::reshape_and_cache_flash",
  116. mutates_args=["kv_cache"])
  117. def reshape_and_cache_flash(
  118. key: torch.Tensor,
  119. value: torch.Tensor,
  120. kv_cache: torch.Tensor,
  121. slot_mapping: torch.Tensor,
  122. kv_cache_dtype: str,
  123. k_scale: float,
  124. v_scale: float,
  125. ) -> None:
  126. """Inductor cannot deal with inplace operations on views.
  127. See https://github.com/pytorch/pytorch/issues/131192
  128. and https://github.com/pytorch/pytorch/issues/130174
  129. This is a workaround to hide the view operation from the inductor.
  130. """
  131. return torch.ops._C_cache_ops.reshape_and_cache_flash(
  132. key, value, kv_cache[0], kv_cache[1], slot_mapping, kv_cache_dtype,
  133. k_scale, v_scale)
  134. @reshape_and_cache_flash.register_fake # type: ignore
  135. def _(
  136. key: torch.Tensor,
  137. value: torch.Tensor,
  138. kv_cache: torch.Tensor,
  139. slot_mapping: torch.Tensor,
  140. kv_cache_dtype: str,
  141. k_scale: float,
  142. v_scale: float,
  143. ) -> None:
  144. pass
  145. class FlashAttentionBackend(AttentionBackend):
  146. @staticmethod
  147. def get_supported_head_sizes() -> List[int]:
  148. return [32, 64, 96, 128, 160, 192, 224, 256]
  149. @staticmethod
  150. def get_name() -> str:
  151. return "flash-attn"
  152. @staticmethod
  153. def get_impl_cls() -> Type["FlashAttentionImpl"]:
  154. return FlashAttentionImpl
  155. @staticmethod
  156. def get_metadata_cls() -> Type["AttentionMetadata"]:
  157. return FlashAttentionMetadata
  158. @staticmethod
  159. def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
  160. return FlashAttentionMetadataBuilder
  161. @staticmethod
  162. def get_state_cls() -> Type["CommonAttentionState"]:
  163. return CommonAttentionState
  164. @staticmethod
  165. def get_kv_cache_shape(
  166. num_blocks: int,
  167. block_size: int,
  168. num_kv_heads: int,
  169. head_size: int,
  170. ) -> Tuple[int, ...]:
  171. if block_size % 16 != 0:
  172. raise ValueError("Block size must be a multiple of 16.")
  173. return (2, num_blocks, block_size, num_kv_heads, head_size)
  174. @staticmethod
  175. def swap_blocks(
  176. src_kv_cache: torch.Tensor,
  177. dst_kv_cache: torch.Tensor,
  178. src_to_dst: torch.Tensor,
  179. ) -> None:
  180. src_key_cache = src_kv_cache[0]
  181. dst_key_cache = dst_kv_cache[0]
  182. ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
  183. src_value_cache = src_kv_cache[1]
  184. dst_value_cache = dst_kv_cache[1]
  185. ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
  186. @staticmethod
  187. def copy_blocks(
  188. kv_caches: List[torch.Tensor],
  189. src_to_dists: torch.Tensor,
  190. ) -> None:
  191. key_caches = [kv_cache[0] for kv_cache in kv_caches]
  192. value_caches = [kv_cache[1] for kv_cache in kv_caches]
  193. ops.copy_blocks(key_caches, value_caches, src_to_dists)
  194. @dataclass
  195. class FlashAttentionMetadata(AttentionMetadata):
  196. """Metadata for FlashAttentionBackend.
  197. NOTE: Any python object stored here is not updated when it is
  198. cuda-graph replayed. If you have values that need to be changed
  199. dynamically, it should be stored in tensor. The tensor has to be
  200. updated from `CUDAGraphRunner.forward` API.
  201. """
  202. # (batch_size,). The sequence length per sequence. Sequence length means
  203. # the computed tokens + new tokens None if it is a decoding.
  204. seq_lens: Optional[List[int]]
  205. # seq_lens stored as a tensor.
  206. seq_lens_tensor: Optional[torch.Tensor]
  207. # NOTE: Definition of context_len, query_len, and seq_len.
  208. # |---------- N-1 iteration --------|
  209. # |---------------- N iteration ---------------------|
  210. # |- tokenA -|......................|-- newTokens ---|
  211. # |---------- context_len ----------|
  212. # |-------------------- seq_len ----------------------|
  213. # |-- query_len ---|
  214. # Maximum query length in the batch. None for decoding.
  215. max_query_len: Optional[int]
  216. # Maximum sequence length among prefill batch. 0 if there are decoding
  217. # requests only.
  218. max_prefill_seq_len: int
  219. # Maximum sequence length among decode batch. 0 if there are prefill
  220. # requests only.
  221. max_decode_seq_len: int
  222. # (batch_size + 1,). The cumulative subquery lengths of the sequences in
  223. # the batch, used to index into subquery. E.g., if the subquery length
  224. # is [4, 6], it is [0, 4, 10].
  225. query_start_loc: Optional[torch.Tensor]
  226. # (batch_size + 1,). The cumulative sequence lengths of the sequences in
  227. # the batch, used to index into sequence. E.g., if the sequence length is
  228. # [4, 6], it is [0, 4, 10].
  229. seq_start_loc: Optional[torch.Tensor]
  230. # (batch_size,) A tensor of context lengths (tokens that are computed
  231. # so far).
  232. context_lens_tensor: Optional[torch.Tensor]
  233. # (batch_size, max_blocks_per_seq).
  234. # Block addresses per sequence. (Seq id -> list of physical block)
  235. # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
  236. # in the kv cache. Each block can contain up to block_size tokens.
  237. # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
  238. # captured.
  239. block_tables: Optional[torch.Tensor]
  240. # Whether or not if cuda graph is enabled.
  241. # Cuda-graph is currently enabled for decoding only.
  242. # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
  243. use_cuda_graph: bool
  244. _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
  245. _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None
  246. @property
  247. def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
  248. if self.num_prefills == 0:
  249. return None
  250. if self._cached_prefill_metadata is not None:
  251. return self._cached_prefill_metadata
  252. assert self.seq_lens is not None
  253. assert self.seq_lens_tensor is not None
  254. assert self.query_start_loc is not None
  255. assert self.context_lens_tensor is not None
  256. assert self.block_tables is not None
  257. assert self.seq_start_loc is not None
  258. self._cached_prefill_metadata = FlashAttentionMetadata(
  259. num_prefills=self.num_prefills,
  260. num_prefill_tokens=self.num_prefill_tokens,
  261. num_decode_tokens=0,
  262. slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
  263. seq_lens=self.seq_lens[:self.num_prefills],
  264. seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
  265. max_query_len=self.max_query_len,
  266. max_prefill_seq_len=self.max_prefill_seq_len,
  267. max_decode_seq_len=0,
  268. query_start_loc=self.query_start_loc[:self.num_prefills + 1],
  269. seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
  270. context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
  271. block_tables=self.block_tables[:self.num_prefills],
  272. use_cuda_graph=False,
  273. )
  274. return self._cached_prefill_metadata
  275. @property
  276. def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
  277. if self.num_decode_tokens == 0:
  278. return None
  279. if self._cached_decode_metadata is not None:
  280. return self._cached_decode_metadata
  281. assert self.block_tables is not None
  282. assert self.seq_lens_tensor is not None
  283. self._cached_decode_metadata = FlashAttentionMetadata(
  284. num_prefills=0,
  285. num_prefill_tokens=0,
  286. num_decode_tokens=self.num_decode_tokens,
  287. slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
  288. seq_lens=None,
  289. seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
  290. max_query_len=None,
  291. max_prefill_seq_len=0,
  292. max_decode_seq_len=self.max_decode_seq_len,
  293. query_start_loc=None,
  294. seq_start_loc=None,
  295. context_lens_tensor=None,
  296. block_tables=self.block_tables[self.num_prefills:],
  297. use_cuda_graph=self.use_cuda_graph,
  298. )
  299. return self._cached_decode_metadata
  300. def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
  301. sampled_token_ids: Optional[torch.Tensor],
  302. block_size: int, num_seqs: int, num_queries: int):
  303. """
  304. Update metadata in-place to advance one decode step.
  305. """
  306. # When using cudagraph, the num_seqs is padded to the next captured
  307. # batch sized, but num_queries tracks the actual number of requests in
  308. # the batch. For --enforce-eager mode, num_seqs == num_queries
  309. if num_seqs != num_queries:
  310. assert num_seqs > num_queries
  311. assert self.use_cuda_graph
  312. assert self.num_prefills == 0
  313. assert self.num_prefill_tokens == 0
  314. assert self.num_decode_tokens == num_seqs
  315. assert self.slot_mapping.shape == (num_seqs, )
  316. assert self.seq_lens is not None
  317. assert len(self.seq_lens) == num_seqs
  318. assert self.seq_lens_tensor is not None
  319. assert self.seq_lens_tensor.shape == (num_seqs, )
  320. assert self.max_query_len == 1
  321. assert self.max_prefill_seq_len == 0
  322. assert self.max_decode_seq_len == max(self.seq_lens)
  323. assert self.query_start_loc is not None
  324. assert self.query_start_loc.shape == (num_queries + 1, )
  325. assert self.seq_start_loc is not None
  326. assert self.seq_start_loc.shape == (num_seqs + 1, )
  327. assert self.context_lens_tensor is not None
  328. assert self.context_lens_tensor.shape == (num_queries, )
  329. assert self.block_tables is not None
  330. assert self.block_tables.shape[0] == num_seqs
  331. # Update query lengths. Note that we update only queries and not seqs,
  332. # since tensors may be padded due to captured cuda graph batch size
  333. for i in range(num_queries):
  334. self.seq_lens[i] += 1
  335. self.max_decode_seq_len = max(self.seq_lens)
  336. ops.advance_step_flashattn(num_seqs=num_seqs,
  337. num_queries=num_queries,
  338. block_size=block_size,
  339. input_tokens=model_input.input_tokens,
  340. sampled_token_ids=sampled_token_ids,
  341. input_positions=model_input.input_positions,
  342. seq_lens=self.seq_lens_tensor,
  343. slot_mapping=self.slot_mapping,
  344. block_tables=self.block_tables)
  345. class FlashAttentionMetadataBuilder(
  346. AttentionMetadataBuilder[FlashAttentionMetadata]):
  347. def __init__(self, input_builder: "ModelInputForGPUBuilder"):
  348. self.slot_mapping: List[int] = []
  349. self.prefill_seq_lens: List[int] = []
  350. self.context_lens: List[int] = []
  351. self.block_tables: List[List[int]] = []
  352. self.curr_seq_lens: List[int] = []
  353. self.num_prefills = 0
  354. self.num_prefill_tokens = 0
  355. self.num_decode_tokens = 0
  356. self.has_prefix_cache_hit = False
  357. self.input_builder = input_builder
  358. self.runner = input_builder.runner
  359. self.sliding_window = input_builder.sliding_window
  360. self.block_size = input_builder.block_size
  361. self.use_v2_block_manager = (
  362. input_builder.scheduler_config.use_v2_block_manager)
  363. def _add_seq_group(
  364. self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
  365. chunked_prefill_enabled: bool, prefix_cache_hit: bool):
  366. """Add a sequence group to the metadata. Specifically update/append
  367. 1. context length.
  368. 2. block table.
  369. 3. slot mapping.
  370. """
  371. is_prompt = inter_data.is_prompt
  372. block_tables = inter_data.block_tables
  373. for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
  374. curr_sliding_window_block) in zip(
  375. inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
  376. inter_data.orig_seq_lens, inter_data.seq_lens,
  377. inter_data.query_lens, inter_data.context_lens,
  378. inter_data.curr_sliding_window_blocks):
  379. self.context_lens.append(context_len)
  380. if is_prompt:
  381. self.num_prefills += 1
  382. self.num_prefill_tokens += token_len
  383. self.prefill_seq_lens.append(seq_len)
  384. else:
  385. assert query_len == 1, (
  386. "seq_len: {}, context_len: {}, query_len: {}".format(
  387. seq_len, context_len, query_len))
  388. self.num_decode_tokens += query_len
  389. self.curr_seq_lens.append(curr_seq_len)
  390. # Compute block table.
  391. # TODO: Combine chunked prefill and prefix caching by
  392. # only allowing multiple of block_size chunk size.
  393. # NOTE: This only works for oooooooxxx style attention.
  394. block_table = []
  395. if prefix_cache_hit:
  396. # NOTE: For flash-attn, the block table should
  397. # include the entries for the incoming prefill tokens.
  398. block_table = block_tables[seq_id]
  399. elif ((chunked_prefill_enabled or not is_prompt)
  400. and block_tables is not None):
  401. if curr_sliding_window_block == 0:
  402. block_table = block_tables[seq_id]
  403. else:
  404. block_table = block_tables[seq_id][
  405. -curr_sliding_window_block:]
  406. self.block_tables.append(block_table)
  407. # Compute slot mapping.
  408. is_profile_run = is_block_tables_empty(block_tables)
  409. start_idx = compute_slot_mapping_start_idx(
  410. is_prompt, query_len, context_len, self.sliding_window,
  411. self.use_v2_block_manager)
  412. compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
  413. seq_len, context_len, start_idx,
  414. self.block_size, inter_data.block_tables)
  415. def build(self, seq_lens: List[int], query_lens: List[int],
  416. cuda_graph_pad_size: int, batch_size: int):
  417. """Build attention metadata with on-device tensors.
  418. Args:
  419. seq_lens: The maybe padded sequence lengths of the input sequences.
  420. query_lens: The query lengths of the input sequences.
  421. cuda_graph_pad_size: The padding size for cuda graph.
  422. -1 if cuda graph is not used.
  423. batch_size: The maybe padded batch size.
  424. """
  425. prefix_cache_hit = any([
  426. inter_data.prefix_cache_hit
  427. for inter_data in self.input_builder.inter_data_list
  428. ])
  429. for inter_data in self.input_builder.inter_data_list:
  430. self._add_seq_group(inter_data,
  431. self.input_builder.chunked_prefill_enabled,
  432. prefix_cache_hit)
  433. device = self.runner.device
  434. use_captured_graph = cuda_graph_pad_size != -1
  435. max_query_len = max(query_lens)
  436. max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
  437. max_decode_seq_len = max(self.curr_seq_lens, default=0)
  438. num_decode_tokens = self.num_decode_tokens
  439. if use_captured_graph:
  440. self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
  441. self.block_tables.extend([] * cuda_graph_pad_size)
  442. num_decode_tokens = batch_size
  443. # The shape of graph_block_tables is
  444. # [max batch size, max context len // block size].
  445. input_block_tables = self.runner.graph_block_tables[:batch_size]
  446. max_blocks = input_block_tables.shape[1]
  447. for i, block_table in enumerate(self.block_tables):
  448. if block_table:
  449. num_blocks = len(block_table)
  450. if num_blocks <= max_blocks:
  451. input_block_tables[i, :num_blocks] = block_table
  452. else:
  453. # It may be possible to have more blocks allocated due
  454. # to lookahead slots of multi-step, however, they are
  455. # not used anyway, so can be safely ignored.
  456. input_block_tables[
  457. i, :max_blocks] = block_table[:max_blocks]
  458. block_tables = torch.from_numpy(input_block_tables).to(
  459. device=device, non_blocking=True)
  460. else:
  461. block_tables = make_tensor_with_pad(
  462. self.block_tables,
  463. pad=0,
  464. dtype=torch.int,
  465. device=device,
  466. )
  467. assert max_query_len > 0, ("query_lens: {}".format(query_lens))
  468. assert device is not None
  469. context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
  470. device, self.runner.pin_memory)
  471. seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
  472. self.runner.pin_memory)
  473. query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
  474. self.runner.pin_memory)
  475. slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
  476. device, self.runner.pin_memory)
  477. query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
  478. dtype=torch.int32,
  479. device=device)
  480. seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
  481. dtype=torch.int32,
  482. device=device)
  483. torch.cumsum(seq_lens_tensor,
  484. dim=0,
  485. dtype=seq_start_loc.dtype,
  486. out=seq_start_loc[1:])
  487. torch.cumsum(query_lens_tensor,
  488. dim=0,
  489. dtype=query_start_loc.dtype,
  490. out=query_start_loc[1:])
  491. return FlashAttentionMetadata(
  492. num_prefills=self.num_prefills,
  493. slot_mapping=slot_mapping_tensor,
  494. num_prefill_tokens=self.num_prefill_tokens,
  495. num_decode_tokens=num_decode_tokens,
  496. seq_lens=seq_lens,
  497. seq_lens_tensor=seq_lens_tensor,
  498. max_query_len=max_query_len,
  499. max_prefill_seq_len=max_prefill_seq_len,
  500. max_decode_seq_len=max_decode_seq_len,
  501. query_start_loc=query_start_loc,
  502. seq_start_loc=seq_start_loc,
  503. context_lens_tensor=context_lens_tensor,
  504. block_tables=block_tables,
  505. use_cuda_graph=use_captured_graph,
  506. )
  507. class FlashAttentionImpl(AttentionImpl):
  508. """
  509. If the input tensors contain prompt tokens, the layout is as follows:
  510. |<--------------- num_prefill_tokens ----------------->|
  511. |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
  512. Otherwise, the layout is as follows:
  513. |<----------------- num_decode_tokens ------------------>|
  514. |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
  515. Generation tokens can contain padding when cuda-graph is used.
  516. Currently, prompt tokens don't contain any padding.
  517. The prompts might have different lengths, while the generation tokens
  518. always have length 1.
  519. If chunked prefill is enabled, prefill tokens and decode tokens can be
  520. batched together in a flattened 1D query.
  521. |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
  522. |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
  523. Currently, cuda graph is disabled for chunked prefill, meaning there's no
  524. padding between prefill and decode tokens.
  525. """
  526. def __init__(
  527. self,
  528. num_heads: int,
  529. head_size: int,
  530. scale: float,
  531. num_kv_heads: int,
  532. alibi_slopes: Optional[List[float]],
  533. sliding_window: Optional[int],
  534. kv_cache_dtype: str,
  535. blocksparse_params: Optional[Dict[str, Any]] = None,
  536. logits_soft_cap: Optional[float] = None,
  537. ) -> None:
  538. if blocksparse_params is not None:
  539. raise ValueError(
  540. "FlashAttention does not support block-sparse attention.")
  541. self.num_heads = num_heads
  542. self.head_size = head_size
  543. self.scale = float(scale)
  544. self.num_kv_heads = num_kv_heads
  545. if alibi_slopes is not None:
  546. alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
  547. self.alibi_slopes = alibi_slopes
  548. self.sliding_window = ((sliding_window, sliding_window)
  549. if sliding_window is not None else (-1, -1))
  550. self.kv_cache_dtype = kv_cache_dtype
  551. if logits_soft_cap is None:
  552. # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
  553. logits_soft_cap = 0
  554. self.logits_soft_cap = logits_soft_cap
  555. assert self.num_heads % self.num_kv_heads == 0
  556. self.num_queries_per_kv = self.num_heads // self.num_kv_heads
  557. if sliding_window is not None:
  558. # NOTE: flash-attn's sliding window does not work with
  559. # paged KV cache.
  560. raise ValueError(
  561. "Sliding window is not supported in FlashAttention.")
  562. support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
  563. if head_size not in support_head_sizes:
  564. raise ValueError(
  565. f"Head size {head_size} is not supported by FlashAttention. "
  566. f"Supported head sizes are: {support_head_sizes}.")
  567. def forward(
  568. self,
  569. query: torch.Tensor,
  570. key: torch.Tensor,
  571. value: torch.Tensor,
  572. kv_cache: torch.Tensor,
  573. attn_metadata: FlashAttentionMetadata,
  574. k_scale: float = 1.0,
  575. v_scale: float = 1.0,
  576. attn_type: AttentionType = AttentionType.DECODER,
  577. ) -> torch.Tensor:
  578. """Forward pass with FlashAttention.
  579. Args:
  580. query: shape = [num_tokens, num_heads * head_size]
  581. key: shape = [num_tokens, num_kv_heads * head_size]
  582. value: shape = [num_tokens, num_kv_heads * head_size]
  583. kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
  584. attn_metadata: Metadata for attention.
  585. Returns:
  586. shape = [num_tokens, num_heads * head_size]
  587. """
  588. if attn_type != AttentionType.DECODER:
  589. raise NotImplementedError("Encoder self-attention and "
  590. "encoder/decoder cross-attention "
  591. "are not implemented for "
  592. "FlashAttentionImpl")
  593. # NOTE: FlashAttention does not support FP8 KV cache.
  594. assert k_scale == 1.0 and v_scale == 1.0, (
  595. "key/v_scale is not supported in FlashAttention.")
  596. num_tokens, hidden_size = query.shape
  597. # Reshape the query, key, and value tensors.
  598. query = query.view(-1, self.num_heads, self.head_size)
  599. key = key.view(-1, self.num_kv_heads, self.head_size)
  600. value = value.view(-1, self.num_kv_heads, self.head_size)
  601. if kv_cache is not None:
  602. key_cache = kv_cache[0]
  603. value_cache = kv_cache[1]
  604. # Reshape the input keys and values and store them in the cache.
  605. # If kv_cache is not provided, the new key and value tensors are
  606. # not cached. This happens during the initial memory profiling run.
  607. torch.ops.aphrodite.reshape_and_cache_flash(
  608. key,
  609. value,
  610. kv_cache,
  611. attn_metadata.slot_mapping.flatten(),
  612. self.kv_cache_dtype,
  613. k_scale,
  614. v_scale,
  615. )
  616. num_prefill_tokens = attn_metadata.num_prefill_tokens
  617. num_decode_tokens = attn_metadata.num_decode_tokens
  618. assert key.shape[0] == num_prefill_tokens + num_decode_tokens
  619. assert value.shape[0] == num_prefill_tokens + num_decode_tokens
  620. # Query for decode. KV is not needed because it is already cached.
  621. decode_query = query[num_prefill_tokens:]
  622. # QKV for prefill.
  623. query = query[:num_prefill_tokens]
  624. key = key[:num_prefill_tokens]
  625. value = value[:num_prefill_tokens]
  626. assert query.shape[0] == num_prefill_tokens
  627. assert decode_query.shape[0] == num_decode_tokens
  628. prefill_output: Optional[torch.Tensor] = None
  629. decode_output: Optional[torch.Tensor] = None
  630. if prefill_meta := attn_metadata.prefill_metadata:
  631. # Prompt run.
  632. if (kv_cache is None or prefill_meta.block_tables is None
  633. or prefill_meta.block_tables.numel() == 0):
  634. # normal attention
  635. # When block_tables are not filled, it means q and k are the
  636. # prompt, and they have the same length.
  637. prefill_output = torch.ops.aphrodite.flash_attn_varlen_func(
  638. q=query,
  639. k=key,
  640. v=value,
  641. cu_seqlens_q=prefill_meta.seq_start_loc,
  642. cu_seqlens_k=prefill_meta.seq_start_loc,
  643. max_seqlen_q=prefill_meta.max_prefill_seq_len,
  644. max_seqlen_k=prefill_meta.max_prefill_seq_len,
  645. softmax_scale=self.scale,
  646. causal=True,
  647. window_size=self.sliding_window,
  648. alibi_slopes=self.alibi_slopes,
  649. softcap=self.logits_soft_cap,
  650. )
  651. else:
  652. # prefix-enabled attention
  653. assert prefill_meta.seq_lens is not None
  654. max_seq_len = max(prefill_meta.seq_lens)
  655. prefill_output = torch.ops.aphrodite.flash_attn_varlen_func( # noqa
  656. q=query,
  657. k=key_cache,
  658. v=value_cache,
  659. cu_seqlens_q=prefill_meta.query_start_loc,
  660. max_seqlen_q=prefill_meta.max_query_len,
  661. cu_seqlens_k=prefill_meta.seq_start_loc,
  662. max_seqlen_k=max_seq_len,
  663. softmax_scale=self.scale,
  664. causal=True,
  665. alibi_slopes=self.alibi_slopes,
  666. block_table=prefill_meta.block_tables,
  667. softcap=self.logits_soft_cap,
  668. )
  669. if decode_meta := attn_metadata.decode_metadata:
  670. # Decoding run.
  671. decode_output = torch.ops.aphrodite.flash_attn_with_kvcache(
  672. decode_query.unsqueeze(1),
  673. key_cache,
  674. value_cache,
  675. block_table=decode_meta.block_tables,
  676. cache_seqlens=decode_meta.seq_lens_tensor,
  677. softmax_scale=self.scale,
  678. causal=True,
  679. alibi_slopes=self.alibi_slopes,
  680. softcap=self.logits_soft_cap,
  681. ).squeeze(1)
  682. if prefill_output is None:
  683. assert decode_output is not None
  684. return decode_output.view(num_decode_tokens, hidden_size)
  685. if decode_output is None:
  686. assert prefill_output is not None
  687. return prefill_output.view(num_prefill_tokens, hidden_size)
  688. output = torch.cat([prefill_output, decode_output], dim=0)
  689. return output.view(num_tokens, hidden_size)