rocm_flash_attn.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
"""Attention layer for ROCm GPUs."""
  2. import os
  3. from dataclasses import dataclass
  4. from typing import Dict, List, Optional, Tuple, Type
  5. import torch
  6. from loguru import logger
  7. from aphrodite.attention.backends.abstract import (
  8. AttentionBackend,
  9. AttentionImpl,
  10. AttentionMetadata,
  11. AttentionMetadataPerStage,
  12. )
  13. from aphrodite.attention.ops.paged_attn import (
  14. PagedAttention,
  15. PagedAttentionMetadata,
  16. )
  17. class ROCmFlashAttentionBackend(AttentionBackend):
  18. @staticmethod
  19. def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
  20. return ROCmFlashAttentionImpl
  21. @staticmethod
  22. def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
  23. return ROCmFlashAttentionMetadata(*args, **kwargs)
  24. @staticmethod
  25. def get_kv_cache_shape(
  26. num_blocks: int,
  27. block_size: int,
  28. num_kv_heads: int,
  29. head_size: int,
  30. ) -> Tuple[int, ...]:
  31. return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
  32. num_kv_heads, head_size)
  33. @staticmethod
  34. def swap_blocks(
  35. src_kv_cache: torch.Tensor,
  36. dst_kv_cache: torch.Tensor,
  37. src_to_dst: Dict[int, int],
  38. ) -> None:
  39. PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
  40. @staticmethod
  41. def copy_blocks(
  42. kv_caches: List[torch.Tensor],
  43. src_to_dists: Dict[int, List[int]],
  44. ) -> None:
  45. PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
                                 PagedAttentionMetadata):
    """Metadata for ROCmFlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.

    NOTE: field order matters — this is a dataclass that also inherits
    fields from its two base classes, so do not reorder fields.
    """
    # One metadata object describes either only prompts or only decodes
    # (ROCmFlashAttentionImpl.forward reads separate `prefill_metadata` and
    # `decode_metadata`). True if all sequences are prompts.
    is_prompt: bool
    # (batch_size,). The prompt length per sequence. None if it is a decoding.
    # Used by the naive prefill path to split the packed prompt tokens.
    prompt_lens: Optional[List[int]]
    # prompt_lens stored as a tensor (passed to the prefix-prefill kernel).
    prompt_lens_tensor: Optional[torch.Tensor]

    # NOTE: Definition of context_len, subquery_len, and seqlen.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seqlen ----------------------|
    #                                   |- subquery_len -|

    # WARNING: context_len has different definition depending on if it is
    # prefill vs decoding. When it is prefill, it doesn't include new tokens.
    # When it is for decoding, it includes a new token.

    # Maximum subquery length in the batch.
    max_subquery_len: Optional[int]
    # Maximum prompt length in the batch. Used as both max_seqlen_q and
    # max_seqlen_k by the flash-attention prefill paths.
    max_prompt_len: Optional[int]
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    subquery_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]

    # Whether CUDA graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool
  88. class ROCmFlashAttentionImpl(AttentionImpl):
  89. """
  90. If the input tensors contain prompt tokens, the layout is as follows:
  91. |<--------------- num_prompt_tokens -------------->|
  92. |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
  93. Otherwise, the layout is as follows:
  94. |<------------------ num_generation_tokens (M) ----------------->|
  95. |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
  96. Generation tokens can contain padding when cuda-graph is used.
  97. Currently, prompt tokens don't contain any padding.
  98. The prompts might have different lengths, while the generation tokens
  99. always have length 1.
  100. If chunked prefill is enabled, prefill tokens and decode tokens can be
  101. batched together in a flattened 1D query.
  102. |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
  103. |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|
  104. Currently, cuda graph is disabled for chunked prefill, meaning there's no
  105. padding between prefill and decode tokens.
  106. """
  107. def __init__(
  108. self,
  109. num_heads: int,
  110. head_size: int,
  111. scale: float,
  112. num_kv_heads: Optional[int] = None,
  113. alibi_slopes: Optional[List[float]] = None,
  114. sliding_window: Optional[int] = None,
  115. ) -> None:
  116. self.num_heads = num_heads
  117. self.head_size = head_size
  118. self.scale = float(scale)
  119. self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
  120. self.sliding_window = ((sliding_window, sliding_window)
  121. if sliding_window is not None else (-1, -1))
  122. if alibi_slopes is not None:
  123. alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
  124. self.alibi_slopes = alibi_slopes
  125. assert self.num_heads % self.num_kv_heads == 0
  126. self.num_queries_per_kv = self.num_heads // self.num_kv_heads
  127. suppored_head_sizes = PagedAttention.get_supported_head_sizes()
  128. if head_size not in suppored_head_sizes:
  129. raise ValueError(
  130. f"Head size {head_size} is not supported by PagedAttention. "
  131. f"Supported head sizes are: {suppored_head_sizes}.")
  132. self.use_naive_attn = torch.cuda.get_device_capability()[0] != 9
  133. # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
  134. self.use_triton_flash_attn = (os.environ.get(
  135. "APHRODITE_USE_TRITON_FLASH_ATTN", "True").lower()
  136. in ("true", "1"))
  137. if self.use_naive_attn:
  138. # AMD Radeon 7900 series (gfx1100) currently does not support
  139. # xFormers nor FlashAttention. As a temporary workaround, we use
  140. # naive PyTorch implementation of attention.
  141. self.attn_fuc = _naive_attention()
  142. logger.debug("Using naive attention in ROCmBackend")
  143. elif self.use_triton_flash_attn:
  144. from aphrodite.attention.ops.triton_flash_attn import ( # noqa: F401
  145. triton_attention, )
  146. self.attn_func = triton_attention
  147. logger.debug("Using Triton FA in ROCmBackend")
  148. else:
  149. from flash_attn import flash_attn_varlen_func # noqa: F401
  150. self.attn_func = flash_attn_varlen_func
  151. logger.debug("Using CK FA in ROCmBackend")
  152. def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
  153. """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
  154. tokens, n_kv_heads, head_dim = x.shape
  155. return (x[:, :,
  156. None, :].expand(tokens, n_kv_heads, n_rep,
  157. head_dim).reshape(tokens, n_kv_heads * n_rep,
  158. head_dim))
  159. def forward(
  160. self,
  161. query: torch.Tensor,
  162. key: torch.Tensor,
  163. value: torch.Tensor,
  164. kv_cache: torch.Tensor,
  165. attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata],
  166. kv_scale: float = 1.0,
  167. ) -> torch.Tensor:
  168. """Forward pass with FlashAttention and PagedAttention.
  169. Args:
  170. query: shape = [num_tokens, num_heads * head_size]
  171. key: shape = [num_tokens, num_kv_heads * head_size]
  172. value: shape = [num_tokens, num_kv_heads * head_size]
  173. kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
  174. attn_metadata: Metadata for attention.
  175. Returns:
  176. shape = [num_tokens, num_heads * head_size]
  177. """
  178. num_tokens, hidden_size = query.shape
  179. # Reshape the query, key, and value tensors.
  180. query = query.view(-1, self.num_heads, self.head_size)
  181. key = key.view(-1, self.num_kv_heads, self.head_size)
  182. value = value.view(-1, self.num_kv_heads, self.head_size)
  183. if kv_cache is not None:
  184. key_cache, value_cache = PagedAttention.split_kv_cache(
  185. kv_cache, self.num_kv_heads, self.head_size)
  186. # Reshape the input keys and values and store them in the cache.
  187. # If kv_cache is not provided, the new key and value tensors are
  188. # not cached. This happens during the initial memory profiling run.
  189. PagedAttention.write_to_paged_cache(
  190. key,
  191. value,
  192. key_cache,
  193. value_cache,
  194. attn_metadata.slot_mapping,
  195. attn_metadata.kv_cache_dtype,
  196. kv_scale,
  197. )
  198. num_prefill_tokens = attn_metadata.num_prefill_tokens
  199. num_decode_tokens = attn_metadata.num_decode_tokens
  200. assert key.shape[0] == num_prefill_tokens + num_decode_tokens
  201. assert value.shape[0] == num_prefill_tokens + num_decode_tokens
  202. output = torch.empty_like(query)
  203. # Query for decode. KV is not needed because it is already cached.
  204. decode_query = query[num_prefill_tokens:]
  205. # QKV for prefill.
  206. query = query[:num_prefill_tokens]
  207. key = key[:num_prefill_tokens]
  208. value = value[:num_prefill_tokens]
  209. assert query.shape[0] == num_prefill_tokens
  210. assert decode_query.shape[0] == num_decode_tokens
  211. if prefill_meta := attn_metadata.prefill_metadata:
  212. # Prompt run.
  213. if kv_cache is None or prefill_meta.block_tables.numel() == 0:
  214. # triton attention
  215. # When block_tables are not filled, it means q and k are the
  216. # prompt, and they have the same length.
  217. if self.use_naive_attn or self.use_triton_flash_attn:
  218. if self.num_kv_heads != self.num_heads:
  219. # Interleave for MQA workaround.
  220. key = self.repeat_kv(key, self.num_queries_per_kv)
  221. value = self.repeat_kv(value, self.num_queries_per_kv)
  222. if self.use_naive_attn:
  223. out = self.attn_fuc(
  224. query,
  225. key,
  226. value,
  227. prefill_meta.prompt_lens,
  228. self.scale,
  229. )
  230. assert output[:num_prefill_tokens].shape == out.shape
  231. output[:num_prefill_tokens] = out
  232. else:
  233. out, _ = self.attn_func(
  234. query,
  235. key,
  236. value,
  237. None,
  238. prefill_meta.seq_start_loc,
  239. prefill_meta.seq_start_loc,
  240. prefill_meta.max_prompt_len,
  241. prefill_meta.max_prompt_len,
  242. True,
  243. self.scale,
  244. )
  245. assert output[:num_prefill_tokens].shape == out.shape
  246. output[:num_prefill_tokens] = out
  247. else:
  248. out = self.attn_func(
  249. q=query,
  250. k=key,
  251. v=value,
  252. cu_seqlens_q=prefill_meta.seq_start_loc,
  253. cu_seqlens_k=prefill_meta.seq_start_loc,
  254. max_seqlen_q=prefill_meta.max_prompt_len,
  255. max_seqlen_k=prefill_meta.max_prompt_len,
  256. softmax_scale=self.scale,
  257. causal=True,
  258. )
  259. assert output[:num_prefill_tokens].shape == out.shape
  260. output[:num_prefill_tokens] = out
  261. else:
  262. # prefix-enabled attention
  263. output[:num_prefill_tokens] = PagedAttention.forward_prefix(
  264. query,
  265. key,
  266. value,
  267. key_cache,
  268. value_cache,
  269. prefill_meta.block_tables,
  270. prefill_meta.subquery_start_loc,
  271. prefill_meta.prompt_lens_tensor,
  272. prefill_meta.context_lens,
  273. prefill_meta.max_subquery_len,
  274. self.alibi_slopes,
  275. )
  276. if decode_meta := attn_metadata.decode_metadata:
  277. # Decoding run.
  278. output[num_prefill_tokens:] = PagedAttention.forward_decode(
  279. decode_query,
  280. key_cache,
  281. value_cache,
  282. decode_meta.block_tables,
  283. decode_meta.context_lens,
  284. decode_meta.max_context_len,
  285. attn_metadata.kv_cache_dtype,
  286. self.num_kv_heads,
  287. self.scale,
  288. self.alibi_slopes,
  289. kv_scale,
  290. )
  291. # Reshape the output tensor.
  292. return output.view(num_tokens, hidden_size)
  293. def _naive_attention(
  294. query: torch.Tensor,
  295. key: torch.Tensor,
  296. value: torch.Tensor,
  297. prompt_lens: List[int],
  298. scale: float,
  299. ) -> torch.Tensor:
  300. num_tokens = query.shape[0]
  301. output = torch.empty_like(query)
  302. start = 0
  303. for _, prompt_len in enumerate(prompt_lens):
  304. end = start + prompt_len
  305. out = _naive_masked_attention(
  306. query[None, start:end],
  307. key[None, start:end],
  308. value[None, start:end],
  309. scale,
  310. )
  311. # TODO: Unnecessary copy. Optimize.
  312. output[start:end].copy_(out)
  313. start += prompt_len
  314. # Using view got RuntimeError: view size is not compatible
  315. # with input tensor's size and stride (at least one
  316. # dimension spans across two contiguous subspaces).
  317. # Use reshape instead.
  318. return output.reshape(num_tokens, -1)
  319. def _naive_masked_attention(
  320. query: torch.Tensor,
  321. key: torch.Tensor,
  322. value: torch.Tensor,
  323. scale: float,
  324. ) -> torch.Tensor:
  325. seq_len, _, _ = query.shape
  326. attn_mask = torch.triu(torch.ones(seq_len,
  327. seq_len,
  328. dtype=query.dtype,
  329. device=query.device),
  330. diagonal=1)
  331. attn_mask = attn_mask * torch.finfo(query.dtype).min
  332. attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
  333. attn_weights = attn_weights + attn_mask.float()
  334. attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
  335. out = torch.einsum("hqk,khd->qhd", attn_weights, value)
  336. return out