rocm_flash_attn.py

  1. """Attention layer ROCm GPUs."""
  2. from dataclasses import dataclass
  3. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
  4. import torch
  5. from loguru import logger
  6. import aphrodite.common.envs as envs
  7. from aphrodite import _custom_ops as ops
  8. from aphrodite.attention.backends.abstract import (AttentionBackend,
  9. AttentionImpl,
  10. AttentionMetadata,
  11. AttentionType)
  12. from aphrodite.attention.backends.utils import (CommonAttentionState,
  13. CommonMetadataBuilder)
  14. from aphrodite.attention.ops.paged_attn import (PagedAttention,
  15. PagedAttentionMetadata)
  16. if TYPE_CHECKING:
  17. from aphrodite.worker.model_runner import (
  18. ModelInputForGPUWithSamplingMetadata)
  19. _PARTITION_SIZE_ROCM = 512
  20. _ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName


class ROCmFlashAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "rocm-flash-attn"

    @staticmethod
    def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
        return ROCmFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return ROCmFlashAttentionMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
        return ROCmFlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for FlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored in a tensor. The tensor has to be
    updated from the `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE: Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|
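    #
    # Illustration of the diagram above (numbers are hypothetical, not taken
    # from this file): if a sequence has 8 tokens already computed and 4 new
    # tokens are scheduled this step, then context_len = 8, query_len = 4,
    # and seq_len = context_len + query_len = 12.
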
    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int]
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool
    # (batch_size,). A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.block_tables is not None
        assert self.seq_start_loc is not None

        self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
            block_tables=self.block_tables[:self.num_prefills],
            use_cuda_graph=False,
        )
        return self._cached_prefill_metadata
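
    # Sketch of how a mixed batch is split (hypothetical sizes): with
    # num_prefills=2, num_prefill_tokens=7 and num_decode_tokens=3, the
    # prefill view above slices the first 7 entries of slot_mapping and the
    # first 2 entries of the per-sequence tensors, while decode_metadata
    # below takes everything after those offsets.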

    @property
    def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        self._cached_decode_metadata = ROCmFlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[self.num_prefills:],
            use_cuda_graph=self.use_cuda_graph,
        )
        return self._cached_decode_metadata

    def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int, num_seqs: int, num_queries: int):
        """
        Update metadata in-place to advance one decode step.
        """
        # When using cudagraph, num_seqs is padded to the next captured
        # batch size, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries.
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0
        assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size.
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)

        ops.advance_step_flashattn(num_seqs=num_seqs,
                                   num_queries=num_queries,
                                   block_size=block_size,
                                   input_tokens=model_input.input_tokens,
                                   sampled_token_ids=sampled_token_ids,
                                   input_positions=model_input.input_positions,
                                   seq_lens=self.seq_lens_tensor,
                                   slot_mapping=self.slot_mapping,
                                   block_tables=self.block_tables)
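
        # Hedged example (hypothetical values): for an eager-mode batch with
        # seq_lens == [5, 7], the loop above yields [6, 8] and
        # max_decode_seq_len == 8; the custom op then updates the device-side
        # tensors (input_tokens, input_positions, seq_lens_tensor,
        # slot_mapping) to match, using block_tables for the new slots.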


class ROCmFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[ROCmFlashAttentionMetadata]):

    _metadata_cls = ROCmFlashAttentionMetadata


def _make_alibi_bias(alibi_slopes: torch.Tensor,
                     dtype: torch.dtype,
                     seq_lens: Optional[List[int]],
                     make_attn_mask: bool = True) -> List[torch.Tensor]:
    attn_biases = []
    if seq_lens:
        for seq_len in seq_lens:
            bias = torch.arange(seq_len, dtype=dtype)
            # NOTE(zhuohan): HF uses
            #     `bias = bias[None, :].repeat(seq_len, 1)`
            # here. We find that both biases give the same results, but
            # the bias below more accurately follows the original ALiBi
            # paper.
            bias = bias[None, :] - bias[:, None]

            num_heads = alibi_slopes.shape[0]
            bias = bias[None, :].repeat(
                (num_heads, 1, 1)).to(alibi_slopes.device)
            bias.mul_(alibi_slopes[:, None, None])
            if make_attn_mask:
                inf_mask = torch.empty(
                    (1, seq_len, seq_len),
                    dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
                        alibi_slopes.device)
                attn_biases.append((bias + inf_mask).to(dtype))
            else:
                attn_biases.append(bias.to(dtype))

    return attn_biases
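
# Worked example for _make_alibi_bias (illustrative only): for seq_len = 3,
# `bias[None, :] - bias[:, None]` produces the relative-position matrix
#     [[ 0,  1,  2],
#      [-1,  0,  1],
#      [-2, -1,  0]]
# which is then scaled per head by `alibi_slopes`. With make_attn_mask=True,
# entries above the diagonal additionally receive -inf so future tokens are
# masked out.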


class ROCmFlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.
    |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
    |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        if blocksparse_params is not None:
            raise ValueError(
                "ROCmFlashAttention does not support blocksparse attention.")
        if logits_soft_cap is not None:
            raise ValueError(
                "ROCmFlashAttention does not support attention logits soft "
                "capping.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.use_naive_attn = False
        # NOTE: Allow for switching between Triton and CK.
        # Defaulting to triton.
        self.use_triton_flash_attn = envs.APHRODITE_USE_TRITON_FLASH_ATTN
        if self.use_triton_flash_attn:
            from aphrodite.attention.ops.triton_flash_attn import (  # noqa: F401
                triton_attention)
            self.attn_func = triton_attention
            logger.debug("Using Triton FA in ROCmBackend")
            if self.sliding_window != (-1, -1):
                logger.warning("ROCm Triton FA does not currently support "
                               "sliding window attention. If using half "
                               "precision, please try using the ROCm CK "
                               "FA backend instead by setting the env var "
                               "`APHRODITE_USE_TRITON_FLASH_ATTN=0`")
        else:
            # If not using triton, navi3x/navi21/navi10 do not use flash-attn
            # either.
            if torch.cuda.get_device_capability()[0] != 9:
                self.use_naive_attn = True
            else:
                try:
                    from flash_attn import flash_attn_varlen_func  # noqa: F401
                    self.attn_func = flash_attn_varlen_func
                    logger.debug("Using CK FA in ROCmBackend")
                except ModuleNotFoundError:
                    self.use_naive_attn = True

            if self.use_naive_attn:
                self.attn_func = _sdpa_attention
                logger.debug("Using naive attention in ROCmBackend")
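
        # Backend selection recap (behavior implied by the branches above):
        #   APHRODITE_USE_TRITON_FLASH_ATTN=1 (default) -> Triton flash attn.
        #   APHRODITE_USE_TRITON_FLASH_ATTN=0 on gfx9 with flash_attn
        #       installed -> CK flash_attn_varlen_func.
        #   Otherwise (e.g. Navi, or flash_attn missing) -> naive SDPA path.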

    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
        tokens, n_kv_heads, head_dim = x.shape
        return (x[:, :,
                  None, :].expand(tokens, n_kv_heads, n_rep,
                                  head_dim).reshape(tokens, n_kv_heads * n_rep,
                                                    head_dim))
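
    # Shape sketch for repeat_kv (hypothetical sizes): x of shape
    # (tokens=4, n_kv_heads=2, head_dim=64) with n_rep=4 expands to
    # (4, 2, 4, 64) and reshapes to (4, 8, 64), matching the query head
    # count for the MQA/GQA workaround used in the naive path.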

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: ROCmFlashAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "ROCmFlashAttentionImpl")

        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                k_scale,
                v_scale,
            )

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            assert prefill_meta.seq_lens is not None
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                attn_masks = None
                if self.use_triton_flash_attn:
                    # Triton flash attention.
                    if self.alibi_slopes is not None:
                        attn_masks = _make_alibi_bias(
                            self.alibi_slopes,
                            query.dtype,
                            attn_metadata.seq_lens,
                            make_attn_mask=False)  # type: ignore
                    out, _ = self.attn_func(
                        query,
                        key,
                        value,
                        None,
                        prefill_meta.seq_start_loc,
                        prefill_meta.seq_start_loc,
                        prefill_meta.max_prefill_seq_len,
                        prefill_meta.max_prefill_seq_len,
                        True,
                        self.scale,
                        attn_masks[0][None]
                        if attn_masks is not None else None,
                    )
                elif self.use_naive_attn:
                    if self.num_kv_heads != self.num_heads:
                        # Interleave for MQA workaround.
                        key = self.repeat_kv(key, self.num_queries_per_kv)
                        value = self.repeat_kv(value, self.num_queries_per_kv)
                    if self.alibi_slopes is not None:
                        attn_masks = _make_alibi_bias(
                            self.alibi_slopes,
                            query.dtype,
                            attn_metadata.seq_lens,
                            make_attn_mask=True)  # type: ignore
                    query = query.movedim(0, query.dim() - 2)
                    key = key.movedim(0, key.dim() - 2)
                    value = value.movedim(0, value.dim() - 2)
                    # SDPA math backend attention.
                    out = self.attn_func(
                        query,
                        key,
                        value,
                        prefill_meta.seq_lens,
                        num_tokens,
                        self.num_heads,
                        self.head_size,
                        self.scale,
                        attn_masks,
                    )
                else:
                    # CK flash attention.
                    out = self.attn_func(
                        q=query,
                        k=key,
                        v=value,
                        cu_seqlens_q=prefill_meta.seq_start_loc,
                        cu_seqlens_k=prefill_meta.seq_start_loc,
                        max_seqlen_q=prefill_meta.max_prefill_seq_len,
                        max_seqlen_k=prefill_meta.max_prefill_seq_len,
                        softmax_scale=self.scale,
                        causal=True,
                        window_size=self.sliding_window,
                        alibi_slopes=self.alibi_slopes,
                    )

                # Common code for prefill.
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # Prefix-enabled attention.
                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    self.kv_cache_dtype,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window[0],
                    k_scale,
                    v_scale,
                )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            # Whether to use rocm custom paged attention or not.
            num_seqs, num_heads, head_size = decode_query.shape
            block_size = value_cache.shape[3]
            gqa_ratio = num_heads // self.num_kv_heads
            use_custom = _use_rocm_custom_paged_attention(
                decode_query.dtype, head_size, block_size, gqa_ratio,
                decode_meta.max_decode_seq_len)
            if use_custom:
                max_seq_len = decode_meta.max_decode_seq_len
                max_num_partitions = (
                    (max_seq_len + _PARTITION_SIZE_ROCM - 1) //
                    _PARTITION_SIZE_ROCM)
                assert _PARTITION_SIZE_ROCM % block_size == 0
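                # Example partition count (hypothetical): with
                # max_seq_len=4096 and _PARTITION_SIZE_ROCM=512,
                # max_num_partitions = (4096 + 511) // 512 = 8, so the
                # temporaries below hold 8 partial results per head that the
                # kernel later reduces.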
                tmp_output = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions, head_size),
                    dtype=output.dtype,
                    device=output.device,
                )
                exp_sums = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions),
                    dtype=torch.float32,
                    device=output.device,
                )
                max_logits = torch.empty_like(exp_sums)
                ops.paged_attention_rocm(
                    output[num_prefill_tokens:],
                    exp_sums,
                    max_logits,
                    tmp_output,
                    decode_query,
                    key_cache,
                    value_cache,
                    self.num_kv_heads,
                    self.scale,
                    decode_meta.block_tables,
                    decode_meta.seq_lens_tensor,
                    block_size,
                    max_seq_len,
                    self.alibi_slopes,
                    self.kv_cache_dtype,
                    k_scale,
                    v_scale,
                )
            else:
                output[num_prefill_tokens:] = PagedAttention.forward_decode(
                    decode_query,
                    key_cache,
                    value_cache,
                    decode_meta.block_tables,
                    decode_meta.seq_lens_tensor,
                    decode_meta.max_decode_seq_len,
                    self.kv_cache_dtype,
                    self.num_kv_heads,
                    self.scale,
                    self.alibi_slopes,
                    k_scale,
                    v_scale,
                )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)


def _sdpa_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    seq_lens: List[int],
    num_tokens: int,
    num_heads: int,
    head_size: int,
    scale: float,
    attn_masks: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
    start = 0
    output = torch.empty((num_tokens, num_heads, head_size),
                         dtype=query.dtype,
                         device=query.device)

    for i, seq_len in enumerate(seq_lens):
        end = start + seq_len
        with torch.backends.cuda.sdp_kernel(enable_math=True,
                                            enable_flash=False,
                                            enable_mem_efficient=False):
            sub_out = torch.nn.functional.scaled_dot_product_attention(
                query[:, start:end, :],
                key[:, start:end, :],
                value[:, start:end, :],
                dropout_p=0.0,
                is_causal=attn_masks is None,
                attn_mask=attn_masks[i] if attn_masks else None,
                scale=scale).movedim(query.dim() - 2, 0)
            output[start:end, :, :] = sub_out
            start = end

    return output
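
# Note on _sdpa_attention layout (as implied by its caller): query/key/value
# arrive as [num_heads, num_tokens, head_size] after the movedim in forward(),
# each sequence is attended to separately with the math-only SDPA backend,
# and the per-sequence outputs are moved back to
# [num_tokens, num_heads, head_size].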


def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                     block_size: int, gqa_ratio: int,
                                     max_seq_len: int) -> bool:
    # ROCm custom paged attention is not supported on Navi (gfx1*).
    return (not _ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16)
            and (head_size == 64 or head_size == 128)
            and (block_size == 16 or block_size == 32)
            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
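
# Example (hypothetical config): fp16 queries with head_size=128,
# block_size=16, gqa_ratio=8 and max_seq_len=2048 on a gfx9 GPU satisfy every
# clause above, so the custom ROCm paged-attention kernel is used; on Navi
# (gfx1*) the same config falls back to PagedAttention.forward_decode.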