rocm_flash_attn.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533
  1. """Attention layer ROCm GPUs."""
  2. from dataclasses import dataclass
  3. from typing import Any, Dict, List, Optional, Tuple, Type
  4. import torch
  5. from loguru import logger
  6. import aphrodite.common.envs as envs
  7. from aphrodite.attention.backends.abstract import (AttentionBackend,
  8. AttentionImpl,
  9. AttentionMetadata,
  10. AttentionType)
  11. from aphrodite.attention.backends.utils import (CommonAttentionState,
  12. CommonMetadataBuilder)
  13. from aphrodite.attention.ops.paged_attn import (PagedAttention,
  14. PagedAttentionMetadata)
  15. class ROCmFlashAttentionBackend(AttentionBackend):
  16. @staticmethod
  17. def get_name() -> str:
  18. return "rocm-flash-attn"
  19. @staticmethod
  20. def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
  21. return ROCmFlashAttentionImpl
  22. @staticmethod
  23. def get_metadata_cls() -> Type["AttentionMetadata"]:
  24. return ROCmFlashAttentionMetadata
  25. @staticmethod
  26. def get_state_cls() -> Type["CommonAttentionState"]:
  27. return CommonAttentionState
  28. @staticmethod
  29. def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
  30. return ROCmFlashAttentionMetadataBuilder
  31. @staticmethod
  32. def get_kv_cache_shape(
  33. num_blocks: int,
  34. block_size: int,
  35. num_kv_heads: int,
  36. head_size: int,
  37. ) -> Tuple[int, ...]:
  38. return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
  39. num_kv_heads, head_size)
  40. @staticmethod
  41. def swap_blocks(
  42. src_kv_cache: torch.Tensor,
  43. dst_kv_cache: torch.Tensor,
  44. src_to_dst: torch.Tensor,
  45. ) -> None:
  46. PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
  47. @staticmethod
  48. def copy_blocks(
  49. kv_caches: List[torch.Tensor],
  50. src_to_dists: torch.Tensor,
  51. ) -> None:
  52. PagedAttention.copy_blocks(kv_caches, src_to_dists)
  53. @dataclass
  54. class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
  55. """Metadata for FlashAttentionBackend.
  56. NOTE: Any python object stored here is not updated when it is
  57. cuda-graph replayed. If you have values that need to be changed
  58. dynamically, it should be stored in tensor. The tensor has to be
  59. updated from `CUDAGraphRunner.forward` API.
  60. """
  61. # (batch_size,). The sequence length per sequence. Sequence length means
  62. # the computed tokens + new tokens None if it is a decoding.
  63. seq_lens: Optional[List[int]]
  64. # seq_lens stored as a tensor.
  65. seq_lens_tensor: Optional[torch.Tensor]
  66. # NOTE: Definition of context_len, query_len, and seq_len.
  67. # |---------- N-1 iteration --------|
  68. # |---------------- N iteration ---------------------|
  69. # |- tokenA -|......................|-- newTokens ---|
  70. # |---------- context_len ----------|
  71. # |-------------------- seq_len ----------------------|
  72. # |-- query_len ---|
  73. # Maximum query length in the batch. None for decoding.
  74. max_query_len: Optional[int]
  75. # Maximum sequence length among prefill batch. 0 if there are decoding
  76. # requests only.
  77. max_prefill_seq_len: int
  78. # Maximum sequence length among decode batch. 0 if there are prefill
  79. # requests only.
  80. max_decode_seq_len: int
  81. # (batch_size + 1,). The cumulative subquery lengths of the sequences in
  82. # the batch, used to index into subquery. E.g., if the subquery length
  83. # is [4, 6], it is [0, 4, 10].
  84. query_start_loc: Optional[torch.Tensor]
  85. # (batch_size + 1,). The cumulative sequence lengths of the sequences in
  86. # the batch, used to index into sequence. E.g., if the sequence length is
  87. # [4, 6], it is [0, 4, 10].
  88. seq_start_loc: Optional[torch.Tensor]
  89. # Whether or not if cuda graph is enabled.
  90. # Cuda-graph is currently enabled for decoding only.
  91. # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
  92. use_cuda_graph: bool
  93. # (batch_size,) A tensor of context lengths (tokens that are computed
  94. # so far).
  95. context_lens_tensor: Optional[torch.Tensor]
  96. _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
  97. _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
  98. @property
  99. def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
  100. if self.num_prefills == 0:
  101. return None
  102. if self._cached_prefill_metadata is not None:
  103. return self._cached_prefill_metadata
  104. assert self.seq_lens is not None
  105. assert self.seq_lens_tensor is not None
  106. assert self.query_start_loc is not None
  107. assert self.context_lens_tensor is not None
  108. assert self.block_tables is not None
  109. assert self.seq_start_loc is not None
  110. self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
  111. num_prefills=self.num_prefills,
  112. num_prefill_tokens=self.num_prefill_tokens,
  113. num_decode_tokens=0,
  114. slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
  115. seq_lens=self.seq_lens[:self.num_prefills],
  116. seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
  117. max_query_len=self.max_query_len,
  118. max_prefill_seq_len=self.max_prefill_seq_len,
  119. max_decode_seq_len=0,
  120. query_start_loc=self.query_start_loc[:self.num_prefills + 1],
  121. seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
  122. context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
  123. block_tables=self.block_tables[:self.num_prefills],
  124. use_cuda_graph=False,
  125. )
  126. return self._cached_prefill_metadata
  127. @property
  128. def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
  129. if self.num_decode_tokens == 0:
  130. return None
  131. if self._cached_decode_metadata is not None:
  132. return self._cached_decode_metadata
  133. assert self.block_tables is not None
  134. assert self.seq_lens_tensor is not None
  135. self._cached_decode_metadata = ROCmFlashAttentionMetadata(
  136. num_prefills=0,
  137. num_prefill_tokens=0,
  138. num_decode_tokens=self.num_decode_tokens,
  139. slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
  140. seq_lens=None,
  141. seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
  142. max_query_len=None,
  143. max_prefill_seq_len=0,
  144. max_decode_seq_len=self.max_decode_seq_len,
  145. query_start_loc=None,
  146. seq_start_loc=None,
  147. context_lens_tensor=None,
  148. block_tables=self.block_tables[self.num_prefills:],
  149. use_cuda_graph=self.use_cuda_graph,
  150. )
  151. return self._cached_decode_metadata
  152. class ROCmFlashAttentionMetadataBuilder(
  153. CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
  154. _metadata_cls = ROCmFlashAttentionMetadata
  155. def _make_alibi_bias(alibi_slopes: torch.Tensor,
  156. dtype: torch.dtype,
  157. seq_lens: Optional[List[int]],
  158. make_attn_mask: bool = True) -> List[torch.Tensor]:
  159. attn_biases = []
  160. if seq_lens:
  161. for seq_len in seq_lens:
  162. bias = torch.arange(seq_len, dtype=dtype)
  163. # NOTE(zhuohan): HF uses
  164. # `bias = bias[None, :].repeat(seq_len, 1)`
  165. # here. We find that both biases give the same results, but
  166. # the bias below more accurately follows the original ALiBi
  167. # paper.
  168. bias = bias[None, :] - bias[:, None]
  169. num_heads = alibi_slopes.shape[0]
  170. bias = bias[None, :].repeat(
  171. (num_heads, 1, 1)).to(alibi_slopes.device)
  172. bias.mul_(alibi_slopes[:, None, None])
  173. if make_attn_mask:
  174. inf_mask = torch.empty(
  175. (1, seq_len, seq_len),
  176. dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
  177. alibi_slopes.device)
  178. attn_biases.append((bias + inf_mask).to(dtype))
  179. else:
  180. attn_biases.append(bias.to(dtype))
  181. return attn_biases
  182. class ROCmFlashAttentionImpl(AttentionImpl):
  183. """
  184. If the input tensors contain prompt tokens, the layout is as follows:
  185. |<--------------- num_prompt_tokens -------------->|
  186. |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
  187. Otherwise, the layout is as follows:
  188. |<------------------ num_generation_tokens (M) ----------------->|
  189. |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
  190. Generation tokens can contain padding when cuda-graph is used.
  191. Currently, prompt tokens don't contain any padding.
  192. The prompts might have different lengths, while the generation tokens
  193. always have length 1.
  194. If chunked prefill is enabled, prefill tokens and decode tokens can be
  195. batched together in a flattened 1D query.
  196. |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
  197. |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|
  198. Currently, cuda graph is disabled for chunked prefill, meaning there's no
  199. padding between prefill and decode tokens.
  200. """
  201. def __init__(
  202. self,
  203. num_heads: int,
  204. head_size: int,
  205. scale: float,
  206. num_kv_heads: int,
  207. alibi_slopes: Optional[List[float]],
  208. sliding_window: Optional[int],
  209. kv_cache_dtype: str,
  210. blocksparse_params: Optional[Dict[str, Any]] = None,
  211. logits_soft_cap: Optional[float] = None,
  212. ) -> None:
  213. if blocksparse_params is not None:
  214. raise ValueError(
  215. "ROCmFlashAttention does not support blocksparse attention.")
  216. if logits_soft_cap is not None:
  217. raise ValueError(
  218. "ROCmFlashAttention does not support attention logits soft "
  219. "capping.")
  220. self.num_heads = num_heads
  221. self.head_size = head_size
  222. self.scale = float(scale)
  223. self.num_kv_heads = num_kv_heads
  224. if alibi_slopes is not None:
  225. alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
  226. self.alibi_slopes = alibi_slopes
  227. self.sliding_window = ((sliding_window, sliding_window)
  228. if sliding_window is not None else (-1, -1))
  229. self.kv_cache_dtype = kv_cache_dtype
  230. assert self.num_heads % self.num_kv_heads == 0
  231. self.num_queries_per_kv = self.num_heads // self.num_kv_heads
  232. supported_head_sizes = PagedAttention.get_supported_head_sizes()
  233. if head_size not in supported_head_sizes:
  234. raise ValueError(
  235. f"Head size {head_size} is not supported by PagedAttention. "
  236. f"Supported head sizes are: {supported_head_sizes}.")
  237. self.use_naive_attn = False
  238. # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
  239. self.use_triton_flash_attn = envs.APHRODITE_USE_TRITON_FLASH_ATTN
  240. if self.use_triton_flash_attn:
  241. from aphrodite.attention.ops.triton_flash_attn import ( # noqa: F401
  242. triton_attention)
  243. self.attn_func = triton_attention
  244. logger.debug("Using Triton FA in ROCmBackend")
  245. if self.sliding_window != (-1, -1):
  246. logger.warning("ROCm Triton FA does not currently support "
  247. "sliding window attention. If using half "
  248. "precision, please try using the ROCm CK "
  249. "FA backend instead by setting the env var "
  250. "`APHRODITE_USE_TRITON_FLASH_ATTN=0`")
  251. else:
  252. # if not using triton, navi3x/navi21/navi10 do not use flash-attn
  253. # either
  254. if torch.cuda.get_device_capability()[0] != 9:
  255. self.use_naive_attn = True
  256. else:
  257. try:
  258. from flash_attn import flash_attn_varlen_func # noqa: F401
  259. self.attn_func = flash_attn_varlen_func
  260. logger.debug("Using CK FA in ROCmBackend")
  261. except ModuleNotFoundError:
  262. self.use_naive_attn = True
  263. if self.use_naive_attn:
  264. self.attn_func = _sdpa_attention
  265. logger.debug("Using naive attention in ROCmBackend")
  266. def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
  267. """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
  268. tokens, n_kv_heads, head_dim = x.shape
  269. return (x[:, :,
  270. None, :].expand(tokens, n_kv_heads, n_rep,
  271. head_dim).reshape(tokens, n_kv_heads * n_rep,
  272. head_dim))
  273. def forward(
  274. self,
  275. query: torch.Tensor,
  276. key: torch.Tensor,
  277. value: torch.Tensor,
  278. kv_cache: torch.Tensor,
  279. attn_metadata: ROCmFlashAttentionMetadata,
  280. k_scale: float = 1.0,
  281. v_scale: float = 1.0,
  282. attn_type: AttentionType = AttentionType.DECODER,
  283. ) -> torch.Tensor:
  284. """Forward pass with FlashAttention and PagedAttention.
  285. Args:
  286. query: shape = [num_tokens, num_heads * head_size]
  287. key: shape = [num_tokens, num_kv_heads * head_size]
  288. value: shape = [num_tokens, num_kv_heads * head_size]
  289. kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
  290. attn_metadata: Metadata for attention.
  291. Returns:
  292. shape = [num_tokens, num_heads * head_size]
  293. """
  294. if attn_type != AttentionType.DECODER:
  295. raise NotImplementedError("Encoder self-attention and "
  296. "encoder/decoder cross-attention "
  297. "are not implemented for "
  298. "ROCmFlashAttentionImpl")
  299. num_tokens, hidden_size = query.shape
  300. # Reshape the query, key, and value tensors.
  301. query = query.view(-1, self.num_heads, self.head_size)
  302. key = key.view(-1, self.num_kv_heads, self.head_size)
  303. value = value.view(-1, self.num_kv_heads, self.head_size)
  304. if kv_cache is not None:
  305. key_cache, value_cache = PagedAttention.split_kv_cache(
  306. kv_cache, self.num_kv_heads, self.head_size)
  307. # Reshape the input keys and values and store them in the cache.
  308. # If kv_cache is not provided, the new key and value tensors are
  309. # not cached. This happens during the initial memory profiling run.
  310. PagedAttention.write_to_paged_cache(
  311. key,
  312. value,
  313. key_cache,
  314. value_cache,
  315. attn_metadata.slot_mapping,
  316. self.kv_cache_dtype,
  317. k_scale,
  318. v_scale,
  319. )
  320. num_prefill_tokens = attn_metadata.num_prefill_tokens
  321. num_decode_tokens = attn_metadata.num_decode_tokens
  322. assert key.shape[0] == num_prefill_tokens + num_decode_tokens
  323. assert value.shape[0] == num_prefill_tokens + num_decode_tokens
  324. output = torch.empty_like(query)
  325. # Query for decode. KV is not needed because it is already cached.
  326. decode_query = query[num_prefill_tokens:]
  327. # QKV for prefill.
  328. query = query[:num_prefill_tokens]
  329. key = key[:num_prefill_tokens]
  330. value = value[:num_prefill_tokens]
  331. assert query.shape[0] == num_prefill_tokens
  332. assert decode_query.shape[0] == num_decode_tokens
  333. if prefill_meta := attn_metadata.prefill_metadata:
  334. # Prompt run.
  335. assert prefill_meta.seq_lens is not None
  336. if kv_cache is None or prefill_meta.block_tables.numel() == 0:
  337. # triton attention
  338. # When block_tables are not filled, it means q and k are the
  339. # prompt, and they have the same length.
  340. attn_masks = None
  341. if self.use_triton_flash_attn:
  342. if self.alibi_slopes is not None:
  343. attn_masks = _make_alibi_bias(
  344. self.alibi_slopes,
  345. query.dtype,
  346. attn_metadata.seq_lens,
  347. make_attn_mask=False) # type: ignore
  348. out, _ = self.attn_func(
  349. query,
  350. key,
  351. value,
  352. None,
  353. prefill_meta.seq_start_loc,
  354. prefill_meta.seq_start_loc,
  355. prefill_meta.max_prefill_seq_len,
  356. prefill_meta.max_prefill_seq_len,
  357. True,
  358. self.scale,
  359. attn_masks[0][None]
  360. if attn_masks is not None else None,
  361. )
  362. elif self.use_naive_attn:
  363. if self.num_kv_heads != self.num_heads:
  364. # Interleave for MQA workaround.
  365. key = self.repeat_kv(key, self.num_queries_per_kv)
  366. value = self.repeat_kv(value, self.num_queries_per_kv)
  367. if self.alibi_slopes is not None:
  368. attn_masks = _make_alibi_bias(
  369. self.alibi_slopes,
  370. query.dtype,
  371. attn_metadata.seq_lens,
  372. make_attn_mask=True) # type: ignore
  373. query = query.movedim(0, query.dim() - 2)
  374. key = key.movedim(0, key.dim() - 2)
  375. value = value.movedim(0, value.dim() - 2)
  376. # sdpa math backend attention
  377. out = self.attn_func(
  378. query,
  379. key,
  380. value,
  381. prefill_meta.seq_lens,
  382. num_tokens,
  383. self.num_heads,
  384. self.head_size,
  385. self.scale,
  386. attn_masks,
  387. )
  388. else:
  389. out = self.attn_func(
  390. q=query,
  391. k=key,
  392. v=value,
  393. cu_seqlens_q=prefill_meta.seq_start_loc,
  394. cu_seqlens_k=prefill_meta.seq_start_loc,
  395. max_seqlen_q=prefill_meta.max_prefill_seq_len,
  396. max_seqlen_k=prefill_meta.max_prefill_seq_len,
  397. softmax_scale=self.scale,
  398. causal=True,
  399. window_size=self.sliding_window,
  400. alibi_slopes=self.alibi_slopes,
  401. )
  402. # common code for prefill
  403. assert output[:num_prefill_tokens].shape == out.shape
  404. output[:num_prefill_tokens] = out
  405. else:
  406. # prefix-enabled attention
  407. output[:num_prefill_tokens] = PagedAttention.forward_prefix(
  408. query,
  409. key,
  410. value,
  411. self.kv_cache_dtype,
  412. key_cache,
  413. value_cache,
  414. prefill_meta.block_tables,
  415. prefill_meta.query_start_loc,
  416. prefill_meta.seq_lens_tensor,
  417. prefill_meta.context_lens_tensor,
  418. prefill_meta.max_query_len,
  419. self.alibi_slopes,
  420. self.sliding_window[0],
  421. k_scale,
  422. v_scale,
  423. )
  424. if decode_meta := attn_metadata.decode_metadata:
  425. # Decoding run.
  426. output[num_prefill_tokens:] = PagedAttention.forward_decode(
  427. decode_query,
  428. key_cache,
  429. value_cache,
  430. decode_meta.block_tables,
  431. decode_meta.seq_lens_tensor,
  432. decode_meta.max_decode_seq_len,
  433. self.kv_cache_dtype,
  434. self.num_kv_heads,
  435. self.scale,
  436. self.alibi_slopes,
  437. k_scale,
  438. v_scale,
  439. )
  440. # Reshape the output tensor.
  441. return output.view(num_tokens, hidden_size)
  442. def _sdpa_attention(
  443. query: torch.Tensor,
  444. key: torch.Tensor,
  445. value: torch.Tensor,
  446. seq_lens: List[int],
  447. num_tokens: int,
  448. num_heads: int,
  449. head_size: int,
  450. scale: float,
  451. attn_masks: Optional[List[torch.Tensor]] = None,
  452. ) -> torch.Tensor:
  453. start = 0
  454. output = torch.empty((num_tokens, num_heads, head_size),
  455. dtype=query.dtype,
  456. device=query.device)
  457. for i, seq_len in enumerate(seq_lens):
  458. end = start + seq_len
  459. with torch.backends.cuda.sdp_kernel(enable_math=True,
  460. enable_flash=False,
  461. enable_mem_efficient=False):
  462. sub_out = torch.nn.functional.scaled_dot_product_attention(
  463. query[:, start:end, :],
  464. key[:, start:end, :],
  465. value[:, start:end, :],
  466. dropout_p=0.0,
  467. is_causal=attn_masks is None,
  468. attn_mask=attn_masks[i] if attn_masks else None,
  469. scale=scale).movedim(query.dim() - 2, 0)
  470. output[start:end, :, :] = sub_out
  471. start = end
  472. return output