rocm_flash_attn.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588
  1. """Attention layer for ROCm GPUs."""
  2. from dataclasses import dataclass
  3. from typing import Any, Dict, List, Optional, Tuple, Type
  4. import torch
  5. from loguru import logger
  6. import aphrodite.common.envs as envs
  7. from aphrodite import _custom_ops as ops
  8. from aphrodite.attention.backends.abstract import (AttentionBackend,
  9. AttentionImpl,
  10. AttentionMetadata,
  11. AttentionType)
  12. from aphrodite.attention.backends.utils import (CommonAttentionState,
  13. CommonMetadataBuilder)
  14. from aphrodite.attention.ops.paged_attn import (PagedAttention,
  15. PagedAttentionMetadata)
  16. _PARTITION_SIZE = 256
  17. ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName
  18. class ROCmFlashAttentionBackend(AttentionBackend):
  19. @staticmethod
  20. def get_name() -> str:
  21. return "rocm-flash-attn"
  22. @staticmethod
  23. def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
  24. return ROCmFlashAttentionImpl
  25. @staticmethod
  26. def get_metadata_cls() -> Type["AttentionMetadata"]:
  27. return ROCmFlashAttentionMetadata
  28. @staticmethod
  29. def get_state_cls() -> Type["CommonAttentionState"]:
  30. return CommonAttentionState
  31. @staticmethod
  32. def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
  33. return ROCmFlashAttentionMetadataBuilder
  34. @staticmethod
  35. def get_kv_cache_shape(
  36. num_blocks: int,
  37. block_size: int,
  38. num_kv_heads: int,
  39. head_size: int,
  40. ) -> Tuple[int, ...]:
  41. return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
  42. num_kv_heads, head_size)
  43. @staticmethod
  44. def swap_blocks(
  45. src_kv_cache: torch.Tensor,
  46. dst_kv_cache: torch.Tensor,
  47. src_to_dst: torch.Tensor,
  48. ) -> None:
  49. PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
  50. @staticmethod
  51. def copy_blocks(
  52. kv_caches: List[torch.Tensor],
  53. src_to_dists: torch.Tensor,
  54. ) -> None:
  55. PagedAttention.copy_blocks(kv_caches, src_to_dists)
  56. @dataclass
  57. class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
  58. """Metadata for FlashAttentionBackend.
  59. NOTE: Any python object stored here is not updated when it is
  60. cuda-graph replayed. If you have values that need to be changed
  61. dynamically, it should be stored in tensor. The tensor has to be
  62. updated from `CUDAGraphRunner.forward` API.
  63. """
  64. # (batch_size,). The sequence length per sequence. Sequence length means
  65. # the computed tokens + new tokens None if it is a decoding.
  66. seq_lens: Optional[List[int]]
  67. # seq_lens stored as a tensor.
  68. seq_lens_tensor: Optional[torch.Tensor]
  69. # NOTE: Definition of context_len, query_len, and seq_len.
  70. # |---------- N-1 iteration --------|
  71. # |---------------- N iteration ---------------------|
  72. # |- tokenA -|......................|-- newTokens ---|
  73. # |---------- context_len ----------|
  74. # |-------------------- seq_len ----------------------|
  75. # |-- query_len ---|
  76. # Maximum query length in the batch. None for decoding.
  77. max_query_len: Optional[int]
  78. # Maximum sequence length among prefill batch. 0 if there are decoding
  79. # requests only.
  80. max_prefill_seq_len: int
  81. # Maximum sequence length among decode batch. 0 if there are prefill
  82. # requests only.
  83. max_decode_seq_len: int
  84. # (batch_size + 1,). The cumulative subquery lengths of the sequences in
  85. # the batch, used to index into subquery. E.g., if the subquery length
  86. # is [4, 6], it is [0, 4, 10].
  87. query_start_loc: Optional[torch.Tensor]
  88. # (batch_size + 1,). The cumulative sequence lengths of the sequences in
  89. # the batch, used to index into sequence. E.g., if the sequence length is
  90. # [4, 6], it is [0, 4, 10].
  91. seq_start_loc: Optional[torch.Tensor]
  92. # Whether or not if cuda graph is enabled.
  93. # Cuda-graph is currently enabled for decoding only.
  94. # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
  95. use_cuda_graph: bool
  96. # (batch_size,) A tensor of context lengths (tokens that are computed
  97. # so far).
  98. context_lens_tensor: Optional[torch.Tensor]
  99. _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
  100. _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
  101. @property
  102. def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
  103. if self.num_prefills == 0:
  104. return None
  105. if self._cached_prefill_metadata is not None:
  106. return self._cached_prefill_metadata
  107. assert self.seq_lens is not None
  108. assert self.seq_lens_tensor is not None
  109. assert self.query_start_loc is not None
  110. assert self.context_lens_tensor is not None
  111. assert self.block_tables is not None
  112. assert self.seq_start_loc is not None
  113. self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
  114. num_prefills=self.num_prefills,
  115. num_prefill_tokens=self.num_prefill_tokens,
  116. num_decode_tokens=0,
  117. slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
  118. seq_lens=self.seq_lens[:self.num_prefills],
  119. seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
  120. max_query_len=self.max_query_len,
  121. max_prefill_seq_len=self.max_prefill_seq_len,
  122. max_decode_seq_len=0,
  123. query_start_loc=self.query_start_loc[:self.num_prefills + 1],
  124. seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
  125. context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
  126. block_tables=self.block_tables[:self.num_prefills],
  127. use_cuda_graph=False,
  128. )
  129. return self._cached_prefill_metadata
  130. @property
  131. def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
  132. if self.num_decode_tokens == 0:
  133. return None
  134. if self._cached_decode_metadata is not None:
  135. return self._cached_decode_metadata
  136. assert self.block_tables is not None
  137. assert self.seq_lens_tensor is not None
  138. self._cached_decode_metadata = ROCmFlashAttentionMetadata(
  139. num_prefills=0,
  140. num_prefill_tokens=0,
  141. num_decode_tokens=self.num_decode_tokens,
  142. slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
  143. seq_lens=None,
  144. seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
  145. max_query_len=None,
  146. max_prefill_seq_len=0,
  147. max_decode_seq_len=self.max_decode_seq_len,
  148. query_start_loc=None,
  149. seq_start_loc=None,
  150. context_lens_tensor=None,
  151. block_tables=self.block_tables[self.num_prefills:],
  152. use_cuda_graph=self.use_cuda_graph,
  153. )
  154. return self._cached_decode_metadata
  155. class ROCmFlashAttentionMetadataBuilder(
  156. CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
  157. _metadata_cls = ROCmFlashAttentionMetadata
  158. def _make_alibi_bias(alibi_slopes: torch.Tensor,
  159. dtype: torch.dtype,
  160. seq_lens: Optional[List[int]],
  161. make_attn_mask: bool = True) -> List[torch.Tensor]:
  162. attn_biases = []
  163. if seq_lens:
  164. for seq_len in seq_lens:
  165. bias = torch.arange(seq_len, dtype=dtype)
  166. # NOTE(zhuohan): HF uses
  167. # `bias = bias[None, :].repeat(seq_len, 1)`
  168. # here. We find that both biases give the same results, but
  169. # the bias below more accurately follows the original ALiBi
  170. # paper.
  171. bias = bias[None, :] - bias[:, None]
  172. num_heads = alibi_slopes.shape[0]
  173. bias = bias[None, :].repeat(
  174. (num_heads, 1, 1)).to(alibi_slopes.device)
  175. bias.mul_(alibi_slopes[:, None, None])
  176. if make_attn_mask:
  177. inf_mask = torch.empty(
  178. (1, seq_len, seq_len),
  179. dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
  180. alibi_slopes.device)
  181. attn_biases.append((bias + inf_mask).to(dtype))
  182. else:
  183. attn_biases.append(bias.to(dtype))
  184. return attn_biases
  185. class ROCmFlashAttentionImpl(AttentionImpl):
  186. """
  187. If the input tensors contain prompt tokens, the layout is as follows:
  188. |<--------------- num_prompt_tokens -------------->|
  189. |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
  190. Otherwise, the layout is as follows:
  191. |<------------------ num_generation_tokens (M) ----------------->|
  192. |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
  193. Generation tokens can contain padding when cuda-graph is used.
  194. Currently, prompt tokens don't contain any padding.
  195. The prompts might have different lengths, while the generation tokens
  196. always have length 1.
  197. If chunked prefill is enabled, prefill tokens and decode tokens can be
  198. batched together in a flattened 1D query.
  199. |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
  200. |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|
  201. Currently, cuda graph is disabled for chunked prefill, meaning there's no
  202. padding between prefill and decode tokens.
  203. """
  204. def __init__(
  205. self,
  206. num_heads: int,
  207. head_size: int,
  208. scale: float,
  209. num_kv_heads: int,
  210. alibi_slopes: Optional[List[float]],
  211. sliding_window: Optional[int],
  212. kv_cache_dtype: str,
  213. blocksparse_params: Optional[Dict[str, Any]] = None,
  214. logits_soft_cap: Optional[float] = None,
  215. ) -> None:
  216. if blocksparse_params is not None:
  217. raise ValueError(
  218. "ROCmFlashAttention does not support blocksparse attention.")
  219. if logits_soft_cap is not None:
  220. raise ValueError(
  221. "ROCmFlashAttention does not support attention logits soft "
  222. "capping.")
  223. self.num_heads = num_heads
  224. self.head_size = head_size
  225. self.scale = float(scale)
  226. self.num_kv_heads = num_kv_heads
  227. if alibi_slopes is not None:
  228. alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
  229. self.alibi_slopes = alibi_slopes
  230. self.sliding_window = ((sliding_window, sliding_window)
  231. if sliding_window is not None else (-1, -1))
  232. self.kv_cache_dtype = kv_cache_dtype
  233. assert self.num_heads % self.num_kv_heads == 0
  234. self.num_queries_per_kv = self.num_heads // self.num_kv_heads
  235. supported_head_sizes = PagedAttention.get_supported_head_sizes()
  236. if head_size not in supported_head_sizes:
  237. raise ValueError(
  238. f"Head size {head_size} is not supported by PagedAttention. "
  239. f"Supported head sizes are: {supported_head_sizes}.")
  240. self.use_naive_attn = False
  241. # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
  242. self.use_triton_flash_attn = envs.APHRODITE_USE_TRITON_FLASH_ATTN
  243. if self.use_triton_flash_attn:
  244. from aphrodite.attention.ops.triton_flash_attn import ( # noqa: F401
  245. triton_attention)
  246. self.attn_func = triton_attention
  247. logger.debug("Using Triton FA in ROCmBackend")
  248. if self.sliding_window != (-1, -1):
  249. logger.warning("ROCm Triton FA does not currently support "
  250. "sliding window attention. If using half "
  251. "precision, please try using the ROCm CK "
  252. "FA backend instead by setting the env var "
  253. "`APHRODITE_USE_TRITON_FLASH_ATTN=0`")
  254. else:
  255. # if not using triton, navi3x/navi21/navi10 do not use flash-attn
  256. # either
  257. if torch.cuda.get_device_capability()[0] != 9:
  258. self.use_naive_attn = True
  259. else:
  260. try:
  261. from flash_attn import flash_attn_varlen_func # noqa: F401
  262. self.attn_func = flash_attn_varlen_func
  263. logger.debug("Using CK FA in ROCmBackend")
  264. except ModuleNotFoundError:
  265. self.use_naive_attn = True
  266. if self.use_naive_attn:
  267. self.attn_func = _sdpa_attention
  268. logger.debug("Using naive attention in ROCmBackend")
  269. def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
  270. """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
  271. tokens, n_kv_heads, head_dim = x.shape
  272. return (x[:, :,
  273. None, :].expand(tokens, n_kv_heads, n_rep,
  274. head_dim).reshape(tokens, n_kv_heads * n_rep,
  275. head_dim))
  276. def forward(
  277. self,
  278. query: torch.Tensor,
  279. key: torch.Tensor,
  280. value: torch.Tensor,
  281. kv_cache: torch.Tensor,
  282. attn_metadata: ROCmFlashAttentionMetadata,
  283. k_scale: float = 1.0,
  284. v_scale: float = 1.0,
  285. attn_type: AttentionType = AttentionType.DECODER,
  286. ) -> torch.Tensor:
  287. """Forward pass with FlashAttention and PagedAttention.
  288. Args:
  289. query: shape = [num_tokens, num_heads * head_size]
  290. key: shape = [num_tokens, num_kv_heads * head_size]
  291. value: shape = [num_tokens, num_kv_heads * head_size]
  292. kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
  293. attn_metadata: Metadata for attention.
  294. Returns:
  295. shape = [num_tokens, num_heads * head_size]
  296. """
  297. if attn_type != AttentionType.DECODER:
  298. raise NotImplementedError("Encoder self-attention and "
  299. "encoder/decoder cross-attention "
  300. "are not implemented for "
  301. "ROCmFlashAttentionImpl")
  302. num_tokens, hidden_size = query.shape
  303. # Reshape the query, key, and value tensors.
  304. query = query.view(-1, self.num_heads, self.head_size)
  305. key = key.view(-1, self.num_kv_heads, self.head_size)
  306. value = value.view(-1, self.num_kv_heads, self.head_size)
  307. if kv_cache is not None:
  308. key_cache, value_cache = PagedAttention.split_kv_cache(
  309. kv_cache, self.num_kv_heads, self.head_size)
  310. # Reshape the input keys and values and store them in the cache.
  311. # If kv_cache is not provided, the new key and value tensors are
  312. # not cached. This happens during the initial memory profiling run.
  313. PagedAttention.write_to_paged_cache(
  314. key,
  315. value,
  316. key_cache,
  317. value_cache,
  318. attn_metadata.slot_mapping,
  319. self.kv_cache_dtype,
  320. k_scale,
  321. v_scale,
  322. )
  323. num_prefill_tokens = attn_metadata.num_prefill_tokens
  324. num_decode_tokens = attn_metadata.num_decode_tokens
  325. assert key.shape[0] == num_prefill_tokens + num_decode_tokens
  326. assert value.shape[0] == num_prefill_tokens + num_decode_tokens
  327. output = torch.empty_like(query)
  328. # Query for decode. KV is not needed because it is already cached.
  329. decode_query = query[num_prefill_tokens:]
  330. # QKV for prefill.
  331. query = query[:num_prefill_tokens]
  332. key = key[:num_prefill_tokens]
  333. value = value[:num_prefill_tokens]
  334. assert query.shape[0] == num_prefill_tokens
  335. assert decode_query.shape[0] == num_decode_tokens
  336. if prefill_meta := attn_metadata.prefill_metadata:
  337. # Prompt run.
  338. assert prefill_meta.seq_lens is not None
  339. if kv_cache is None or prefill_meta.block_tables.numel() == 0:
  340. # triton attention
  341. # When block_tables are not filled, it means q and k are the
  342. # prompt, and they have the same length.
  343. attn_masks = None
  344. if self.use_triton_flash_attn:
  345. if self.alibi_slopes is not None:
  346. attn_masks = _make_alibi_bias(
  347. self.alibi_slopes,
  348. query.dtype,
  349. attn_metadata.seq_lens,
  350. make_attn_mask=False) # type: ignore
  351. out, _ = self.attn_func(
  352. query,
  353. key,
  354. value,
  355. None,
  356. prefill_meta.seq_start_loc,
  357. prefill_meta.seq_start_loc,
  358. prefill_meta.max_prefill_seq_len,
  359. prefill_meta.max_prefill_seq_len,
  360. True,
  361. self.scale,
  362. attn_masks[0][None]
  363. if attn_masks is not None else None,
  364. )
  365. elif self.use_naive_attn:
  366. if self.num_kv_heads != self.num_heads:
  367. # Interleave for MQA workaround.
  368. key = self.repeat_kv(key, self.num_queries_per_kv)
  369. value = self.repeat_kv(value, self.num_queries_per_kv)
  370. if self.alibi_slopes is not None:
  371. attn_masks = _make_alibi_bias(
  372. self.alibi_slopes,
  373. query.dtype,
  374. attn_metadata.seq_lens,
  375. make_attn_mask=True) # type: ignore
  376. query = query.movedim(0, query.dim() - 2)
  377. key = key.movedim(0, key.dim() - 2)
  378. value = value.movedim(0, value.dim() - 2)
  379. # sdpa math backend attention
  380. out = self.attn_func(
  381. query,
  382. key,
  383. value,
  384. prefill_meta.seq_lens,
  385. num_tokens,
  386. self.num_heads,
  387. self.head_size,
  388. self.scale,
  389. attn_masks,
  390. )
  391. else:
  392. out = self.attn_func(
  393. q=query,
  394. k=key,
  395. v=value,
  396. cu_seqlens_q=prefill_meta.seq_start_loc,
  397. cu_seqlens_k=prefill_meta.seq_start_loc,
  398. max_seqlen_q=prefill_meta.max_prefill_seq_len,
  399. max_seqlen_k=prefill_meta.max_prefill_seq_len,
  400. softmax_scale=self.scale,
  401. causal=True,
  402. window_size=self.sliding_window,
  403. alibi_slopes=self.alibi_slopes,
  404. )
  405. # common code for prefill
  406. assert output[:num_prefill_tokens].shape == out.shape
  407. output[:num_prefill_tokens] = out
  408. else:
  409. # prefix-enabled attention
  410. output[:num_prefill_tokens] = PagedAttention.forward_prefix(
  411. query,
  412. key,
  413. value,
  414. self.kv_cache_dtype,
  415. key_cache,
  416. value_cache,
  417. prefill_meta.block_tables,
  418. prefill_meta.query_start_loc,
  419. prefill_meta.seq_lens_tensor,
  420. prefill_meta.context_lens_tensor,
  421. prefill_meta.max_query_len,
  422. self.alibi_slopes,
  423. self.sliding_window[0],
  424. k_scale,
  425. v_scale,
  426. )
  427. if decode_meta := attn_metadata.decode_metadata:
  428. # Decoding run.
  429. # Whether to use rocm custom paged attention or not
  430. num_seqs, num_heads, head_size = decode_query.shape
  431. block_size = value_cache.shape[3]
  432. gqa_ratio = num_heads // self.num_kv_heads
  433. use_custom = use_rocm_custom_paged_attention(
  434. decode_query.dtype, head_size, block_size, self.kv_cache_dtype,
  435. gqa_ratio, decode_meta.max_decode_seq_len)
  436. if use_custom:
  437. max_seq_len = decode_meta.max_decode_seq_len
  438. max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
  439. _PARTITION_SIZE)
  440. assert _PARTITION_SIZE % block_size == 0
  441. tmp_output = torch.empty(
  442. size=(num_seqs, num_heads, max_num_partitions, head_size),
  443. dtype=output.dtype,
  444. device=output.device,
  445. )
  446. exp_sums = torch.empty(
  447. size=(num_seqs, num_heads, max_num_partitions),
  448. dtype=torch.float32,
  449. device=output.device,
  450. )
  451. max_logits = torch.empty_like(exp_sums)
  452. ops.paged_attention_rocm(
  453. output[num_prefill_tokens:],
  454. exp_sums,
  455. max_logits,
  456. tmp_output,
  457. decode_query,
  458. key_cache,
  459. value_cache,
  460. self.num_kv_heads,
  461. self.scale,
  462. decode_meta.block_tables,
  463. decode_meta.seq_lens_tensor,
  464. block_size,
  465. max_seq_len,
  466. self.alibi_slopes,
  467. self.kv_cache_dtype,
  468. )
  469. else:
  470. output[num_prefill_tokens:] = PagedAttention.forward_decode(
  471. decode_query,
  472. key_cache,
  473. value_cache,
  474. decode_meta.block_tables,
  475. decode_meta.seq_lens_tensor,
  476. decode_meta.max_decode_seq_len,
  477. self.kv_cache_dtype,
  478. self.num_kv_heads,
  479. self.scale,
  480. self.alibi_slopes,
  481. k_scale,
  482. v_scale,
  483. )
  484. # Reshape the output tensor.
  485. return output.view(num_tokens, hidden_size)
  486. def _sdpa_attention(
  487. query: torch.Tensor,
  488. key: torch.Tensor,
  489. value: torch.Tensor,
  490. seq_lens: List[int],
  491. num_tokens: int,
  492. num_heads: int,
  493. head_size: int,
  494. scale: float,
  495. attn_masks: Optional[List[torch.Tensor]] = None,
  496. ) -> torch.Tensor:
  497. start = 0
  498. output = torch.empty((num_tokens, num_heads, head_size),
  499. dtype=query.dtype,
  500. device=query.device)
  501. for i, seq_len in enumerate(seq_lens):
  502. end = start + seq_len
  503. with torch.backends.cuda.sdp_kernel(enable_math=True,
  504. enable_flash=False,
  505. enable_mem_efficient=False):
  506. sub_out = torch.nn.functional.scaled_dot_product_attention(
  507. query[:, start:end, :],
  508. key[:, start:end, :],
  509. value[:, start:end, :],
  510. dropout_p=0.0,
  511. is_causal=attn_masks is None,
  512. attn_mask=attn_masks[i] if attn_masks else None,
  513. scale=scale).movedim(query.dim() - 2, 0)
  514. output[start:end, :, :] = sub_out
  515. start = end
  516. return output
  517. def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
  518. block_size: int, kv_cache_dtype: str,
  519. gqa_ratio: int, max_seq_len: int) -> bool:
  520. # rocm custom page attention not support on navi (gfx1*)
  521. return (not ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16)
  522. and (head_size == 64 or head_size == 128)
  523. and (block_size == 16 or block_size == 32)
  524. and kv_cache_dtype == "auto"
  525. and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)