flash_attn.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566
  1. """Attention layer with FlashAttention."""
  2. from dataclasses import dataclass
  3. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
  4. import torch
  5. from aphrodite_flash_attn import (flash_attn_varlen_func,
  6. flash_attn_with_kvcache)
  7. from aphrodite import _custom_ops as ops
  8. from aphrodite.attention.backends.abstract import (AttentionBackend,
  9. AttentionImpl,
  10. AttentionMetadata,
  11. AttentionMetadataBuilder,
  12. AttentionType)
  13. from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
  14. compute_slot_mapping,
  15. compute_slot_mapping_start_idx,
  16. is_block_tables_empty)
  17. from aphrodite.common.utils import async_tensor_h2d, make_tensor_with_pad
  18. if TYPE_CHECKING:
  19. from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder
  20. class FlashAttentionBackend(AttentionBackend):
  21. @staticmethod
  22. def get_supported_head_sizes() -> List[int]:
  23. return [32, 64, 96, 128, 160, 192, 224, 256]
  24. @staticmethod
  25. def get_name() -> str:
  26. return "flash-attn"
  27. @staticmethod
  28. def get_impl_cls() -> Type["FlashAttentionImpl"]:
  29. return FlashAttentionImpl
  30. @staticmethod
  31. def get_metadata_cls() -> Type["AttentionMetadata"]:
  32. return FlashAttentionMetadata
  33. @staticmethod
  34. def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
  35. return FlashAttentionMetadataBuilder
  36. @staticmethod
  37. def get_kv_cache_shape(
  38. num_blocks: int,
  39. block_size: int,
  40. num_kv_heads: int,
  41. head_size: int,
  42. ) -> Tuple[int, ...]:
  43. if block_size % 16 != 0:
  44. raise ValueError("Block size must be a multiple of 16.")
  45. return (2, num_blocks, block_size, num_kv_heads, head_size)
  46. @staticmethod
  47. def swap_blocks(
  48. src_kv_cache: torch.Tensor,
  49. dst_kv_cache: torch.Tensor,
  50. src_to_dst: torch.Tensor,
  51. ) -> None:
  52. src_key_cache = src_kv_cache[0]
  53. dst_key_cache = dst_kv_cache[0]
  54. ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
  55. src_value_cache = src_kv_cache[1]
  56. dst_value_cache = dst_kv_cache[1]
  57. ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
  58. @staticmethod
  59. def copy_blocks(
  60. kv_caches: List[torch.Tensor],
  61. src_to_dists: torch.Tensor,
  62. ) -> None:
  63. key_caches = [kv_cache[0] for kv_cache in kv_caches]
  64. value_caches = [kv_cache[1] for kv_cache in kv_caches]
  65. ops.copy_blocks(key_caches, value_caches, src_to_dists)
  66. @dataclass
  67. class FlashAttentionMetadata(AttentionMetadata):
  68. """Metadata for FlashAttentionBackend.
  69. NOTE: Any python object stored here is not updated when it is
  70. cuda-graph replayed. If you have values that need to be changed
  71. dynamically, it should be stored in tensor. The tensor has to be
  72. updated from `CUDAGraphRunner.forward` API.
  73. """
  74. # (batch_size,). The sequence length per sequence. Sequence length means
  75. # the computed tokens + new tokens None if it is a decoding.
  76. seq_lens: Optional[List[int]]
  77. # seq_lens stored as a tensor.
  78. seq_lens_tensor: Optional[torch.Tensor]
  79. # NOTE: Definition of context_len, query_len, and seq_len.
  80. # |---------- N-1 iteration --------|
  81. # |---------------- N iteration ---------------------|
  82. # |- tokenA -|......................|-- newTokens ---|
  83. # |---------- context_len ----------|
  84. # |-------------------- seq_len ----------------------|
  85. # |-- query_len ---|
  86. # Maximum query length in the batch. None for decoding.
  87. max_query_len: Optional[int]
  88. # Maximum sequence length among prefill batch. 0 if there are decoding
  89. # requests only.
  90. max_prefill_seq_len: int
  91. # Maximum sequence length among decode batch. 0 if there are prefill
  92. # requests only.
  93. max_decode_seq_len: int
  94. # (batch_size + 1,). The cumulative subquery lengths of the sequences in
  95. # the batch, used to index into subquery. E.g., if the subquery length
  96. # is [4, 6], it is [0, 4, 10].
  97. query_start_loc: Optional[torch.Tensor]
  98. # (batch_size + 1,). The cumulative sequence lengths of the sequences in
  99. # the batch, used to index into sequence. E.g., if the sequence length is
  100. # [4, 6], it is [0, 4, 10].
  101. seq_start_loc: Optional[torch.Tensor]
  102. # (batch_size,) A tensor of context lengths (tokens that are computed
  103. # so far).
  104. context_lens_tensor: Optional[torch.Tensor]
  105. # (batch_size, max_blocks_per_seq).
  106. # Block addresses per sequence. (Seq id -> list of physical block)
  107. # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
  108. # in the kv cache. Each block can contain up to block_size tokens.
  109. # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
  110. # captured.
  111. block_tables: Optional[torch.Tensor]
  112. # Whether or not if cuda graph is enabled.
  113. # Cuda-graph is currently enabled for decoding only.
  114. # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
  115. use_cuda_graph: bool
  116. _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
  117. _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None
  118. @property
  119. def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
  120. if self.num_prefills == 0:
  121. return None
  122. if self._cached_prefill_metadata is not None:
  123. return self._cached_prefill_metadata
  124. assert self.seq_lens is not None
  125. assert self.seq_lens_tensor is not None
  126. assert self.query_start_loc is not None
  127. assert self.context_lens_tensor is not None
  128. assert self.block_tables is not None
  129. assert self.seq_start_loc is not None
  130. self._cached_prefill_metadata = FlashAttentionMetadata(
  131. num_prefills=self.num_prefills,
  132. num_prefill_tokens=self.num_prefill_tokens,
  133. num_decode_tokens=0,
  134. slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
  135. seq_lens=self.seq_lens[:self.num_prefills],
  136. seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
  137. max_query_len=self.max_query_len,
  138. max_prefill_seq_len=self.max_prefill_seq_len,
  139. max_decode_seq_len=0,
  140. query_start_loc=self.query_start_loc[:self.num_prefills + 1],
  141. seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
  142. context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
  143. block_tables=self.block_tables[:self.num_prefills],
  144. use_cuda_graph=False,
  145. )
  146. return self._cached_prefill_metadata
  147. @property
  148. def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
  149. if self.num_decode_tokens == 0:
  150. return None
  151. if self._cached_decode_metadata is not None:
  152. return self._cached_decode_metadata
  153. assert self.block_tables is not None
  154. assert self.seq_lens_tensor is not None
  155. self._cached_decode_metadata = FlashAttentionMetadata(
  156. num_prefills=0,
  157. num_prefill_tokens=0,
  158. num_decode_tokens=self.num_decode_tokens,
  159. slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
  160. seq_lens=None,
  161. seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
  162. max_query_len=None,
  163. max_prefill_seq_len=0,
  164. max_decode_seq_len=self.max_decode_seq_len,
  165. query_start_loc=None,
  166. seq_start_loc=None,
  167. context_lens_tensor=None,
  168. block_tables=self.block_tables[self.num_prefills:],
  169. use_cuda_graph=self.use_cuda_graph,
  170. )
  171. return self._cached_decode_metadata
  172. class FlashAttentionMetadataBuilder(
  173. AttentionMetadataBuilder[FlashAttentionMetadata]):
  174. def __init__(self, input_builder: "ModelInputForGPUBuilder"):
  175. self.slot_mapping: List[int] = []
  176. self.prefill_seq_lens: List[int] = []
  177. self.context_lens: List[int] = []
  178. self.block_tables: List[List[int]] = []
  179. self.curr_seq_lens: List[int] = []
  180. self.num_prefills = 0
  181. self.num_prefill_tokens = 0
  182. self.num_decode_tokens = 0
  183. self.has_prefix_cache_hit = False
  184. self.input_builder = input_builder
  185. self.runner = input_builder.runner
  186. self.sliding_window = input_builder.sliding_window
  187. self.block_size = input_builder.block_size
  188. self.use_v2_block_manager = (
  189. input_builder.scheduler_config.use_v2_block_manager)
  190. def _add_seq_group(
  191. self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
  192. chunked_prefill_enabled: bool, prefix_cache_hit: bool):
  193. """Add a sequence group to the metadata. Specifically update/append
  194. 1. context length.
  195. 2. block table.
  196. 3. slot mapping.
  197. """
  198. is_prompt = inter_data.is_prompt
  199. block_tables = inter_data.block_tables
  200. for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
  201. curr_sliding_window_block) in zip(
  202. inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
  203. inter_data.orig_seq_lens, inter_data.seq_lens,
  204. inter_data.query_lens, inter_data.context_lens,
  205. inter_data.curr_sliding_window_blocks):
  206. self.context_lens.append(context_len)
  207. if is_prompt:
  208. self.num_prefills += 1
  209. self.num_prefill_tokens += token_len
  210. self.prefill_seq_lens.append(seq_len)
  211. else:
  212. assert query_len == 1, (
  213. "seq_len: {}, context_len: {}, query_len: {}".format(
  214. seq_len, context_len, query_len))
  215. self.num_decode_tokens += query_len
  216. self.curr_seq_lens.append(curr_seq_len)
  217. # Compute block table.
  218. # TODO: Combine chunked prefill and prefix caching by
  219. # only allowing multiple of block_size chunk size.
  220. # NOTE: This only works for oooooooxxx style attention.
  221. block_table = []
  222. if prefix_cache_hit:
  223. # NOTE: For flash-attn, the block table should
  224. # include the entries for the incoming prefill tokens.
  225. block_table = block_tables[seq_id]
  226. elif ((chunked_prefill_enabled or not is_prompt)
  227. and block_tables is not None):
  228. block_table = block_tables[seq_id][-curr_sliding_window_block:]
  229. self.block_tables.append(block_table)
  230. # Compute slot mapping.
  231. is_profile_run = is_block_tables_empty(block_tables)
  232. start_idx = compute_slot_mapping_start_idx(
  233. is_prompt, query_len, context_len, self.sliding_window,
  234. self.use_v2_block_manager)
  235. compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
  236. seq_len, context_len, start_idx,
  237. self.block_size, inter_data.block_tables)
  238. def build(self, seq_lens: List[int], query_lens: List[int],
  239. cuda_graph_pad_size: int, batch_size: int):
  240. """Build attention metadata with on-device tensors.
  241. Args:
  242. seq_lens: The maybe padded sequence lengths of the input sequences.
  243. query_lens: The query lengths of the input sequences.
  244. cuda_graph_pad_size: The padding size for cuda graph.
  245. -1 if cuda graph is not used.
  246. batch_size: The maybe padded batch size.
  247. """
  248. prefix_cache_hit = any([
  249. inter_data.prefix_cache_hit
  250. for inter_data in self.input_builder.inter_data_list
  251. ])
  252. for inter_data in self.input_builder.inter_data_list:
  253. self._add_seq_group(inter_data,
  254. self.input_builder.chunked_prefill_enabled,
  255. prefix_cache_hit)
  256. device = self.runner.device
  257. use_captured_graph = cuda_graph_pad_size != -1
  258. max_query_len = max(query_lens)
  259. max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
  260. max_decode_seq_len = max(self.curr_seq_lens, default=0)
  261. num_decode_tokens = self.num_decode_tokens
  262. if use_captured_graph:
  263. self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
  264. self.block_tables.extend([] * cuda_graph_pad_size)
  265. num_decode_tokens = batch_size
  266. # The shape of graph_block_tables is
  267. # [max batch size, max context len // block size].
  268. input_block_tables = self.runner.graph_block_tables[:batch_size]
  269. for i, block_table in enumerate(self.block_tables):
  270. if block_table:
  271. input_block_tables[i, :len(block_table)] = block_table
  272. block_tables = torch.from_numpy(input_block_tables).to(
  273. device=device, non_blocking=True)
  274. else:
  275. block_tables = make_tensor_with_pad(
  276. self.block_tables,
  277. pad=0,
  278. dtype=torch.int,
  279. device=device,
  280. )
  281. assert max_query_len > 0, ("query_lens: {}".format(query_lens))
  282. assert device is not None
  283. context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
  284. device, self.runner.pin_memory)
  285. seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
  286. self.runner.pin_memory)
  287. query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
  288. self.runner.pin_memory)
  289. slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
  290. device, self.runner.pin_memory)
  291. query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
  292. dtype=torch.int32,
  293. device=device)
  294. seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
  295. dtype=torch.int32,
  296. device=device)
  297. torch.cumsum(seq_lens_tensor,
  298. dim=0,
  299. dtype=seq_start_loc.dtype,
  300. out=seq_start_loc[1:])
  301. torch.cumsum(query_lens_tensor,
  302. dim=0,
  303. dtype=query_start_loc.dtype,
  304. out=query_start_loc[1:])
  305. return FlashAttentionMetadata(
  306. num_prefills=self.num_prefills,
  307. slot_mapping=slot_mapping_tensor,
  308. num_prefill_tokens=self.num_prefill_tokens,
  309. num_decode_tokens=num_decode_tokens,
  310. seq_lens=seq_lens,
  311. seq_lens_tensor=seq_lens_tensor,
  312. max_query_len=max_query_len,
  313. max_prefill_seq_len=max_prefill_seq_len,
  314. max_decode_seq_len=max_decode_seq_len,
  315. query_start_loc=query_start_loc,
  316. seq_start_loc=seq_start_loc,
  317. context_lens_tensor=context_lens_tensor,
  318. block_tables=block_tables,
  319. use_cuda_graph=use_captured_graph,
  320. )
  321. class FlashAttentionImpl(AttentionImpl):
  322. """
  323. If the input tensors contain prompt tokens, the layout is as follows:
  324. |<--------------- num_prefill_tokens ----------------->|
  325. |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
  326. Otherwise, the layout is as follows:
  327. |<----------------- num_decode_tokens ------------------>|
  328. |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
  329. Generation tokens can contain padding when cuda-graph is used.
  330. Currently, prompt tokens don't contain any padding.
  331. The prompts might have different lengths, while the generation tokens
  332. always have length 1.
  333. If chunked prefill is enabled, prefill tokens and decode tokens can be
  334. batched together in a flattened 1D query.
  335. |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
  336. |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
  337. Currently, cuda graph is disabled for chunked prefill, meaning there's no
  338. padding between prefill and decode tokens.
  339. """
  340. def __init__(
  341. self,
  342. num_heads: int,
  343. head_size: int,
  344. scale: float,
  345. num_kv_heads: int,
  346. alibi_slopes: Optional[List[float]],
  347. sliding_window: Optional[int],
  348. kv_cache_dtype: str,
  349. blocksparse_params: Optional[Dict[str, Any]] = None,
  350. logits_soft_cap: Optional[float] = None,
  351. ) -> None:
  352. if blocksparse_params is not None:
  353. raise ValueError(
  354. "FlashAttention does not support block-sparse attention.")
  355. self.num_heads = num_heads
  356. self.head_size = head_size
  357. self.scale = float(scale)
  358. self.num_kv_heads = num_kv_heads
  359. if alibi_slopes is not None:
  360. alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
  361. self.alibi_slopes = alibi_slopes
  362. self.sliding_window = ((sliding_window, sliding_window)
  363. if sliding_window is not None else (-1, -1))
  364. self.kv_cache_dtype = kv_cache_dtype
  365. if logits_soft_cap is None:
  366. # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
  367. logits_soft_cap = 0
  368. self.logits_soft_cap = logits_soft_cap
  369. assert self.num_heads % self.num_kv_heads == 0
  370. self.num_queries_per_kv = self.num_heads // self.num_kv_heads
  371. if sliding_window is not None:
  372. # NOTE: flash-attn's sliding window does not work with
  373. # paged KV cache.
  374. raise ValueError(
  375. "Sliding window is not supported in FlashAttention.")
  376. support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
  377. if head_size not in support_head_sizes:
  378. raise ValueError(
  379. f"Head size {head_size} is not supported by FlashAttention. "
  380. f"Supported head sizes are: {support_head_sizes}.")
  381. def forward(
  382. self,
  383. query: torch.Tensor,
  384. key: torch.Tensor,
  385. value: torch.Tensor,
  386. kv_cache: torch.Tensor,
  387. attn_metadata: FlashAttentionMetadata,
  388. k_scale: float = 1.0,
  389. v_scale: float = 1.0,
  390. attn_type: AttentionType = AttentionType.DECODER,
  391. ) -> torch.Tensor:
  392. """Forward pass with FlashAttention.
  393. Args:
  394. query: shape = [num_tokens, num_heads * head_size]
  395. key: shape = [num_tokens, num_kv_heads * head_size]
  396. value: shape = [num_tokens, num_kv_heads * head_size]
  397. kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
  398. attn_metadata: Metadata for attention.
  399. Returns:
  400. shape = [num_tokens, num_heads * head_size]
  401. """
  402. if attn_type != AttentionType.DECODER:
  403. raise NotImplementedError("Encoder self-attention and "
  404. "encoder/decoder cross-attention "
  405. "are not implemented for "
  406. "FlashAttentionImpl")
  407. # NOTE: FlashAttention does not support FP8 KV cache.
  408. assert k_scale == 1.0 and v_scale == 1.0, (
  409. "key/v_scale is not supported in FlashAttention.")
  410. num_tokens, hidden_size = query.shape
  411. # Reshape the query, key, and value tensors.
  412. query = query.view(-1, self.num_heads, self.head_size)
  413. key = key.view(-1, self.num_kv_heads, self.head_size)
  414. value = value.view(-1, self.num_kv_heads, self.head_size)
  415. if kv_cache is not None:
  416. key_cache = kv_cache[0]
  417. value_cache = kv_cache[1]
  418. # Reshape the input keys and values and store them in the cache.
  419. # If kv_cache is not provided, the new key and value tensors are
  420. # not cached. This happens during the initial memory profiling run.
  421. ops.reshape_and_cache_flash(
  422. key,
  423. value,
  424. key_cache,
  425. value_cache,
  426. attn_metadata.slot_mapping.flatten(),
  427. self.kv_cache_dtype,
  428. k_scale,
  429. v_scale,
  430. )
  431. num_prefill_tokens = attn_metadata.num_prefill_tokens
  432. num_decode_tokens = attn_metadata.num_decode_tokens
  433. assert key.shape[0] == num_prefill_tokens + num_decode_tokens
  434. assert value.shape[0] == num_prefill_tokens + num_decode_tokens
  435. output = torch.empty_like(query)
  436. # Query for decode. KV is not needed because it is already cached.
  437. decode_query = query[num_prefill_tokens:]
  438. # QKV for prefill.
  439. query = query[:num_prefill_tokens]
  440. key = key[:num_prefill_tokens]
  441. value = value[:num_prefill_tokens]
  442. assert query.shape[0] == num_prefill_tokens
  443. assert decode_query.shape[0] == num_decode_tokens
  444. if prefill_meta := attn_metadata.prefill_metadata:
  445. # Prompt run.
  446. if (kv_cache is None or prefill_meta.block_tables is None
  447. or prefill_meta.block_tables.numel() == 0):
  448. # normal attention
  449. # When block_tables are not filled, it means q and k are the
  450. # prompt, and they have the same length.
  451. out = flash_attn_varlen_func(
  452. q=query,
  453. k=key,
  454. v=value,
  455. cu_seqlens_q=prefill_meta.seq_start_loc,
  456. cu_seqlens_k=prefill_meta.seq_start_loc,
  457. max_seqlen_q=prefill_meta.max_prefill_seq_len,
  458. max_seqlen_k=prefill_meta.max_prefill_seq_len,
  459. softmax_scale=self.scale,
  460. causal=True,
  461. window_size=self.sliding_window,
  462. alibi_slopes=self.alibi_slopes,
  463. softcap=self.logits_soft_cap,
  464. )
  465. assert output[:num_prefill_tokens].shape == out.shape
  466. output[:num_prefill_tokens] = out
  467. else:
  468. # prefix-enabled attention
  469. assert prefill_meta.seq_lens is not None
  470. max_seq_len = max(prefill_meta.seq_lens)
  471. output[:num_prefill_tokens] = flash_attn_varlen_func(
  472. q=query,
  473. k=key_cache,
  474. v=value_cache,
  475. cu_seqlens_q=prefill_meta.query_start_loc,
  476. max_seqlen_q=prefill_meta.max_query_len,
  477. cu_seqlens_k=prefill_meta.seq_start_loc,
  478. max_seqlen_k=max_seq_len,
  479. softmax_scale=self.scale,
  480. causal=True,
  481. alibi_slopes=self.alibi_slopes,
  482. block_table=prefill_meta.block_tables,
  483. softcap=self.logits_soft_cap,
  484. )
  485. if decode_meta := attn_metadata.decode_metadata:
  486. # Decoding run.
  487. output[num_prefill_tokens:] = flash_attn_with_kvcache(
  488. decode_query.unsqueeze(1),
  489. key_cache,
  490. value_cache,
  491. block_table=decode_meta.block_tables,
  492. cache_seqlens=decode_meta.seq_lens_tensor,
  493. softmax_scale=self.scale,
  494. causal=True,
  495. alibi_slopes=self.alibi_slopes,
  496. ).squeeze(1)
  497. # Reshape the output tensor.
  498. return output.view(num_tokens, hidden_size)