# rocm_flash_attn.py

  1. """Attention layer ROCm GPUs."""
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type

import torch
from loguru import logger

from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                   AttentionImpl,
                                                   AttentionMetadata,
                                                   AttentionMetadataPerStage)
from aphrodite.attention.ops.paged_attn import (PagedAttention,
                                                PagedAttentionMetadata)


class ROCmFlashAttentionBackend(AttentionBackend):

    @staticmethod
    def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
        return ROCmFlashAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
        return ROCmFlashAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
                                 PagedAttentionMetadata):
    """Metadata for ROCmFlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored in a tensor. The tensor has to be
    updated from the `CUDAGraphRunner.forward` API.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # (batch_size,). The prompt length per sequence. None if it is a decoding.
    prompt_lens: Optional[List[int]]
    # prompt_lens stored as a tensor.
    prompt_lens_tensor: Optional[torch.Tensor]

    # NOTE: Definition of context_len, subquery_len, and seqlen.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seqlen ----------------------|
    #                                   |- subquery_len -|

    # WARNING: context_len has different definition depending on if it is
    # prefill vs decoding. When it is prefill, it doesn't include new tokens.
    # When it is for decoding, it includes a new token.
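
    # Hypothetical example (chunked prefill): if 12 prompt tokens were
    # processed in earlier iterations and 4 new tokens are fed now, then
    # context_len = 12, subquery_len = 4, and seqlen = 16. In a decode
    # iteration the subquery is a single token and, per the warning above,
    # context_len counts that new token as well.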

    # Maximum subquery length in the batch.
    max_subquery_len: Optional[int]
    # Maximum prompt length in the batch.
    max_prompt_len: Optional[int]
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    subquery_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool


class ROCmFlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.
    |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
    |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.use_naive_attn = False
        # NOTE: Allow for switching between Triton and CK. Defaulting to Triton.
        self.use_triton_flash_attn = (os.environ.get(
            "APHRODITE_USE_TRITON_FLASH_ATTN", "True").lower()
                                      in ("true", "1"))
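        # Setting APHRODITE_USE_TRITON_FLASH_ATTN to anything other than
        # "true"/"1" selects the CK flash-attn path below, or the naive
        # fallback on navi3x / when flash-attn is not installed.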
        if self.use_triton_flash_attn:
            from aphrodite.attention.ops.triton_flash_attn import \
                triton_attention  # noqa: F401
            self.attn_func = triton_attention
            logger.debug("Using Triton FA in ROCmBackend")
        else:
            # If not using Triton, navi3x does not use flash-attn either.
            if torch.cuda.get_device_capability()[0] == 11:
                self.use_naive_attn = True
            else:
                try:
                    from flash_attn import flash_attn_varlen_func  # noqa: F401
                    self.attn_func = flash_attn_varlen_func
                    logger.debug("Using CK FA in ROCmBackend")
                except ModuleNotFoundError:
                    self.use_naive_attn = True

            if self.use_naive_attn:
                self.attn_func = _naive_attention
                logger.debug("Using naive attention in ROCmBackend")

    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
        tokens, n_kv_heads, head_dim = x.shape
        return (x[:, :, None, :].expand(tokens, n_kv_heads, n_rep, head_dim)
                .reshape(tokens, n_kv_heads * n_rep, head_dim))

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata],
        kv_scale: float = 1.0,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                attn_metadata.kv_cache_dtype,
                kv_scale,
            )

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]
        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            assert prefill_meta.prompt_lens is not None
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                if self.use_triton_flash_attn:
                    out, _ = self.attn_func(
                        query,
                        key,
                        value,
                        None,
                        prefill_meta.seq_start_loc,
                        prefill_meta.seq_start_loc,
                        prefill_meta.max_prompt_len,
                        prefill_meta.max_prompt_len,
                        True,
                        self.scale,
                    )
                elif self.use_naive_attn:
                    if self.num_kv_heads != self.num_heads:
                        # Interleave for MQA workaround.
                        key = self.repeat_kv(key, self.num_queries_per_kv)
                        value = self.repeat_kv(value, self.num_queries_per_kv)
                    out = self.attn_func(
                        query,
                        key,
                        value,
                        prefill_meta.prompt_lens,
                        self.scale,
                    )
                else:
                    out = self.attn_func(
                        q=query,
                        k=key,
                        v=value,
                        cu_seqlens_q=prefill_meta.seq_start_loc,
                        cu_seqlens_k=prefill_meta.seq_start_loc,
                        max_seqlen_q=prefill_meta.max_prompt_len,
                        max_seqlen_k=prefill_meta.max_prompt_len,
                        softmax_scale=self.scale,
                        causal=True,
                    )

                # Common code for prefill.
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # Prefix-enabled attention.
                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.subquery_start_loc,
                    prefill_meta.prompt_lens_tensor,
                    prefill_meta.context_lens,
                    prefill_meta.max_subquery_len,
                    self.alibi_slopes,
                )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.context_lens,
                decode_meta.max_context_len,
                attn_metadata.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
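
# Pure-PyTorch fallback path, used when the Triton kernel is disabled and CK
# flash-attn is unavailable (navi3x, or flash-attn not installed). Attention
# is computed per prompt with an explicit causal mask.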
def _naive_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    prompt_lens: List[int],
    scale: float,
) -> torch.Tensor:
    output = torch.empty_like(query)
    start = 0
    for prompt_len in prompt_lens:
        end = start + prompt_len
        out = _naive_masked_attention(
            query[start:end],
            key[start:end],
            value[start:end],
            scale,
        )
        # TODO: Unnecessary copy. Optimize.
        output[start:end].copy_(out)
        start += prompt_len

    return output


def _naive_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
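    # Explicit causal attention over a single prompt. query/key/value have
    # shape (seq_len, num_heads, head_dim); positions above the diagonal are
    # filled with the dtype's most negative value so softmax gives them
    # (effectively) zero weight.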
    seq_len, head_size, head_dim = query.shape
    attn_mask = torch.triu(torch.ones(seq_len,
                                      seq_len,
                                      dtype=query.dtype,
                                      device=query.device),
                           diagonal=1)
    attn_mask = attn_mask * torch.finfo(query.dtype).min

    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out
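

if __name__ == "__main__":
    # Minimal CPU sketch of the naive fallback path: random tensors in the
    # [num_tokens, num_heads, head_dim] layout used above. The sizes here
    # are arbitrary, illustrative values.
    num_heads, head_dim = 4, 64
    prompt_lens = [3, 5]
    total_tokens = sum(prompt_lens)
    q = torch.randn(total_tokens, num_heads, head_dim)
    k = torch.randn(total_tokens, num_heads, head_dim)
    v = torch.randn(total_tokens, num_heads, head_dim)
    out = _naive_attention(q, k, v, prompt_lens, scale=head_dim**-0.5)
    assert out.shape == (total_tokens, num_heads, head_dim)
    print("naive attention output shape:", tuple(out.shape))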