rocm_flash_attn.py

  1. """Attention layer ROCm GPUs."""
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
from loguru import logger

from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                   AttentionImpl,
                                                   AttentionMetadata,
                                                   AttentionType)
from aphrodite.attention.backends.utils import CommonMetadataBuilder
from aphrodite.attention.ops.paged_attn import (PagedAttention,
                                                PagedAttentionMetadata)


class ROCmFlashAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "rocm-flash-attn"

    @staticmethod
    def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
        return ROCmFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return ROCmFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
        return ROCmFlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for ROCmFlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored as tensors. The tensors have to be
    updated from the `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding run.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE: Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|
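    # A worked example with illustrative numbers: if a sequence already has 16
    # computed tokens (context_len = 16) and this step feeds it 4 new tokens
    # (query_len = 4), then seq_len = context_len + query_len = 20.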

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int]
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool
    # (batch_size,). A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.block_tables is not None
        assert self.seq_start_loc is not None

        self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
            block_tables=self.block_tables[:self.num_prefills],
            use_cuda_graph=False,
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        self._cached_decode_metadata = ROCmFlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[self.num_prefills:],
            use_cuda_graph=self.use_cuda_graph,
        )
        return self._cached_decode_metadata
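
    # An illustrative sketch with hypothetical batch values (not part of the
    # class): with num_prefills=2, two fresh prompts of 4 and 5 tokens
    # (num_prefill_tokens=9), and num_decode_tokens=3, `prefill_metadata`
    # keeps slot_mapping[:9], seq_lens[:2] and block_tables[:2], while
    # `decode_metadata` keeps slot_mapping[9:], seq_lens_tensor[2:] and
    # block_tables[2:]; the two cached views partition the flattened batch
    # without copying it.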


class ROCmFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[ROCmFlashAttentionMetadata]):

    _metadata_cls = ROCmFlashAttentionMetadata


def _make_alibi_bias(alibi_slopes: torch.Tensor,
                     dtype: torch.dtype,
                     seq_lens: Optional[List[int]],
                     make_attn_mask: bool = True) -> List[torch.Tensor]:
    attn_biases = []
    if seq_lens:
        for seq_len in seq_lens:
            bias = torch.arange(seq_len, dtype=dtype)
            # NOTE(zhuohan): HF uses
            #     `bias = bias[None, :].repeat(seq_len, 1)`
            # here. We find that both biases give the same results, but
            # the bias below more accurately follows the original ALiBi
            # paper.
            bias = bias[None, :] - bias[:, None]

            num_heads = alibi_slopes.shape[0]
            bias = bias[None, :].repeat(
                (num_heads, 1, 1)).to(alibi_slopes.device)
            bias.mul_(alibi_slopes[:, None, None])
            if make_attn_mask:
                inf_mask = torch.empty(
                    (1, seq_len, seq_len),
                    dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
                        alibi_slopes.device)
                attn_biases.append((bias + inf_mask).to(dtype))
            else:
                attn_biases.append(bias.to(dtype))

    return attn_biases
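
# Shape sketch for `_make_alibi_bias` (illustrative values only): with
# alibi_slopes of shape (num_heads,) = (2,) and seq_lens=[3], each returned
# bias has shape (2, 3, 3). With make_attn_mask=True the strictly
# upper-triangular entries are -inf so future positions are masked; otherwise
# only the slope-scaled relative-position bias (j - i) is returned.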


class ROCmFlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.
    |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
    |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        assert blocksparse_params is None, ValueError(
            "ROCmFlashAttention does not support blocksparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.use_naive_attn = False
        # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
        self.use_triton_flash_attn = (os.environ.get(
            "APHRODITE_USE_TRITON_FLASH_ATTN", "True").lower()
                                      in ("true", "1"))
        if self.use_triton_flash_attn:
            from aphrodite.attention.ops.triton_flash_attn import \
                triton_attention  # noqa: F401
            self.attn_func = triton_attention
            logger.debug("Using Triton FA in ROCmBackend")
            if self.sliding_window != (-1, -1):
                logger.warning("ROCm Triton FA does not currently support "
                               "sliding window attention. If using half "
                               "precision, please try using the ROCm CK "
                               "FA backend instead by setting the env var "
                               "`APHRODITE_USE_TRITON_FLASH_ATTN=0`")
        else:
            # if not using triton, navi3x/navi21/navi10 do not use flash-attn
            # either
            if torch.cuda.get_device_capability()[0] != 9:
                self.use_naive_attn = True
            else:
                try:
                    from flash_attn import flash_attn_varlen_func  # noqa: F401
                    self.attn_func = flash_attn_varlen_func
                    logger.debug("Using CK FA in ROCmBackend")
                except ModuleNotFoundError:
                    self.use_naive_attn = True

            if self.use_naive_attn:
                self.attn_func = _sdpa_attention
                logger.debug("Using naive attention in ROCmBackend")
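
    # Backend selection sketch (assumes a ROCm build of Aphrodite): Triton FA
    # is the default; exporting `APHRODITE_USE_TRITON_FLASH_ATTN=0` switches
    # to the CK flash-attn path when the GPU reports major capability 9 and
    # `flash_attn` is importable, and otherwise falls back to the naive
    # `_sdpa_attention` path.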

    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
        tokens, n_kv_heads, head_dim = x.shape
        return (x[:, :,
                  None, :].expand(tokens, n_kv_heads, n_rep,
                                  head_dim).reshape(tokens, n_kv_heads * n_rep,
                                                    head_dim))
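
    # Example with illustrative shapes: for x of shape
    # (tokens=5, n_kv_heads=2, head_dim=64) and n_rep=4, repeat_kv returns a
    # (5, 8, 64) tensor, equivalent to
    # torch.repeat_interleave(x, repeats=4, dim=1).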

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: ROCmFlashAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "ROCmFlashAttentionImpl")

        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                k_scale,
                v_scale,
            )

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            assert prefill_meta.seq_lens is not None
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # triton attention
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                attn_masks = None
                if self.use_triton_flash_attn:
                    if self.alibi_slopes is not None:
                        attn_masks = _make_alibi_bias(
                            self.alibi_slopes,
                            query.dtype,
                            attn_metadata.seq_lens,
                            make_attn_mask=False)  # type: ignore
                    out, _ = self.attn_func(
                        query,
                        key,
                        value,
                        None,
                        prefill_meta.seq_start_loc,
                        prefill_meta.seq_start_loc,
                        prefill_meta.max_prefill_seq_len,
                        prefill_meta.max_prefill_seq_len,
                        True,
                        self.scale,
                        attn_masks[0][None]
                        if attn_masks is not None else None,
                    )
                elif self.use_naive_attn:
                    if self.num_kv_heads != self.num_heads:
                        # Interleave for MQA workaround.
                        key = self.repeat_kv(key, self.num_queries_per_kv)
                        value = self.repeat_kv(value, self.num_queries_per_kv)
                    if self.alibi_slopes is not None:
                        attn_masks = _make_alibi_bias(
                            self.alibi_slopes,
                            query.dtype,
                            attn_metadata.seq_lens,
                            make_attn_mask=True)  # type: ignore
                    query = query.movedim(0, query.dim() - 2)
                    key = key.movedim(0, key.dim() - 2)
                    value = value.movedim(0, value.dim() - 2)
                    # sdpa math backend attention
                    out = self.attn_func(
                        query,
                        key,
                        value,
                        prefill_meta.seq_lens,
                        num_tokens,
                        self.num_heads,
                        self.head_size,
                        self.scale,
                        attn_masks,
                    )
                else:
                    out = self.attn_func(
                        q=query,
                        k=key,
                        v=value,
                        cu_seqlens_q=prefill_meta.seq_start_loc,
                        cu_seqlens_k=prefill_meta.seq_start_loc,
                        max_seqlen_q=prefill_meta.max_prefill_seq_len,
                        max_seqlen_k=prefill_meta.max_prefill_seq_len,
                        softmax_scale=self.scale,
                        causal=True,
                        window_size=self.sliding_window,
                        alibi_slopes=self.alibi_slopes,
                    )

                # common code for prefill
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # prefix-enabled attention
                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window[0],
                )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.seq_lens_tensor,
                decode_meta.max_decode_seq_len,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                k_scale,
                v_scale,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)


def _sdpa_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    seq_lens: List[int],
    num_tokens: int,
    num_heads: int,
    head_size: int,
    scale: float,
    attn_masks: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
    start = 0
    output = torch.empty((num_tokens, num_heads, head_size),
                         dtype=query.dtype,
                         device=query.device)

    for i, seq_len in enumerate(seq_lens):
        end = start + seq_len
        with torch.backends.cuda.sdp_kernel(enable_math=True,
                                            enable_flash=False,
                                            enable_mem_efficient=False):
            sub_out = torch.nn.functional.scaled_dot_product_attention(
                query[:, start:end, :],
                key[:, start:end, :],
                value[:, start:end, :],
                dropout_p=0.0,
                is_causal=attn_masks is None,
                attn_mask=attn_masks[i] if attn_masks else None,
                scale=scale).movedim(query.dim() - 2, 0)
            output[start:end, :, :] = sub_out
            start = end

    return output
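

# Usage sketch for `_sdpa_attention` (illustrative shapes; assumes a ROCm/CUDA
# device). Inputs arrive as (num_heads, num_tokens, head_size) after the
# movedim in `ROCmFlashAttentionImpl.forward`:
#
#     q = torch.randn(8, 7, 64, device="cuda", dtype=torch.float16)
#     out = _sdpa_attention(q, q, q, seq_lens=[3, 4], num_tokens=7,
#                           num_heads=8, head_size=64, scale=64 ** -0.5)
#
# `out` has shape (7, 8, 64): tokens first again, with one causal SDPA call
# per sequence in seq_lens.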