# rocm_flash_attn.py
  1. """Attention layer ROCm GPUs."""
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
from loguru import logger

from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                   AttentionImpl,
                                                   AttentionMetadata)
from aphrodite.attention.ops.paged_attn import (PagedAttention,
                                                PagedAttentionMetadata)


class ROCmFlashAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "rocm-flash-attn"

    @staticmethod
    def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
        return ROCmFlashAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
        return ROCmFlashAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for FlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE: Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|
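    #
    # Worked example (added comment): with chunked prefill, a sequence that
    # has already computed 8 tokens and is now processing 4 new ones has
    # context_len = 8, query_len = 4, and seq_len = 8 + 4 = 12.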

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int]
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # (batch_size,). A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
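
    # NOTE (added comment): the prefill_metadata / decode_metadata properties
    # below assume the scheduler places prefill requests before decode
    # requests in the batch, so slicing by num_prefills / num_prefill_tokens
    # splits every per-sequence and per-token field cleanly.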

    @property
    def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.block_tables is not None
        assert self.seq_start_loc is not None

        self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
            block_tables=self.block_tables[:self.num_prefills],
            use_cuda_graph=False,
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        self._cached_decode_metadata = ROCmFlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[self.num_prefills:],
            use_cuda_graph=self.use_cuda_graph,
        )
        return self._cached_decode_metadata


class ROCmFlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
    |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        assert blocksparse_params is None, ValueError(
            "ROCm FlashAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")
        self.use_naive_attn = False
        # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
        self.use_triton_flash_attn = (os.environ.get(
            "APHRODITE_USE_TRITON_FLASH_ATTN", "True").lower()
                                      in ("true", "1"))
        if self.use_triton_flash_attn:
            from aphrodite.attention.ops.triton_flash_attn import \
                triton_attention  # noqa: F401
            self.attn_func = triton_attention
            logger.debug("Using Triton FA in ROCmBackend")
        else:
            # if not using triton, navi3x/navi21/navi10 do not use flash-attn
            # either
            if torch.cuda.get_device_capability()[0] != 9:
                self.use_naive_attn = True
            else:
                try:
                    from flash_attn import flash_attn_varlen_func  # noqa: F401
                    self.attn_func = flash_attn_varlen_func
                    logger.debug("Using CK FA in ROCmBackend")
                except ModuleNotFoundError:
                    self.use_naive_attn = True

            if self.use_naive_attn:
                self.attn_func = _naive_attention
                logger.debug("Using naive attention in ROCmBackend")

    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
        tokens, n_kv_heads, head_dim = x.shape
        return (x[:, :,
                  None, :].expand(tokens, n_kv_heads, n_rep,
                                  head_dim).reshape(tokens, n_kv_heads * n_rep,
                                                    head_dim))
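
    # Shape note for repeat_kv above (added comment): with 16 query heads and
    # 4 KV heads, n_rep = 4 expands [tokens, 4, head_dim] into
    # [tokens, 16, head_dim]; expand() creates a view and only the final
    # reshape() materializes the copy.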

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: ROCmFlashAttentionMetadata,
        kv_scale: float = 1.0,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                kv_scale,
            )

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]
        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            assert prefill_meta.seq_lens is not None
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # triton attention
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                if self.use_triton_flash_attn:
                    out, _ = self.attn_func(
                        query,
                        key,
                        value,
                        None,
                        prefill_meta.seq_start_loc,
                        prefill_meta.seq_start_loc,
                        prefill_meta.max_prefill_seq_len,
                        prefill_meta.max_prefill_seq_len,
                        True,
                        self.scale,
                    )
                elif self.use_naive_attn:
                    if self.num_kv_heads != self.num_heads:
                        # Interleave for MQA workaround.
                        key = self.repeat_kv(key, self.num_queries_per_kv)
                        value = self.repeat_kv(value, self.num_queries_per_kv)
                    out = self.attn_func(
                        query,
                        key,
                        value,
                        prefill_meta.seq_lens,
                        self.scale,
                    )
                else:
                    out = self.attn_func(
                        q=query,
                        k=key,
                        v=value,
                        cu_seqlens_q=prefill_meta.seq_start_loc,
                        cu_seqlens_k=prefill_meta.seq_start_loc,
                        max_seqlen_q=prefill_meta.max_prefill_seq_len,
                        max_seqlen_k=prefill_meta.max_prefill_seq_len,
                        softmax_scale=self.scale,
                        causal=True,
                    )

                # common code for prefill
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # prefix-enabled attention
                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window[0],
                )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.seq_lens_tensor,
                decode_meta.max_decode_seq_len,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
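

# Reference implementations (added comment): _naive_attention loops over the
# sequences of a prefill batch, and _naive_masked_attention applies causally
# masked scaled-dot-product attention to one sequence at a time.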
def _naive_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    seq_lens: List[int],
    scale: float,
) -> torch.Tensor:
    output = torch.empty_like(query)
    start = 0
    for seq_len in seq_lens:
        end = start + seq_len
        out = _naive_masked_attention(
            query[start:end],
            key[start:end],
            value[start:end],
            scale,
        )
        # TODO: Unnecessary copy. Optimize.
        output[start:end].copy_(out)
        start += seq_len

    return output


def _naive_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    seq_len, num_heads, head_dim = query.shape
    attn_mask = torch.triu(torch.ones(seq_len,
                                      seq_len,
                                      dtype=query.dtype,
                                      device=query.device),
                           diagonal=1)
    attn_mask = attn_mask * torch.finfo(query.dtype).min

    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out
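

# Minimal smoke test (added; not part of the original module). It exercises
# the naive reference path on random CPU tensors purely to illustrate the
# expected [num_tokens, num_heads, head_dim] layout; the sizes are arbitrary.
if __name__ == "__main__":
    torch.manual_seed(0)
    num_heads, head_dim = 4, 64
    seq_lens = [3, 5]  # two "prompts" packed into one flattened batch
    total_tokens = sum(seq_lens)
    q = torch.randn(total_tokens, num_heads, head_dim)
    k = torch.randn(total_tokens, num_heads, head_dim)
    v = torch.randn(total_tokens, num_heads, head_dim)
    out = _naive_attention(q, k, v, seq_lens, scale=head_dim**-0.5)
    print(out.shape)  # torch.Size([8, 4, 64])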