xformers.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. """Attention layer with xFormers and PagedAttention."""
  2. from dataclasses import dataclass
  3. from typing import Dict, List, Optional, Tuple, Type
  4. import torch
  5. from xformers import ops as xops
  6. from xformers.ops.fmha.attn_bias import (AttentionBias,
  7. BlockDiagonalCausalMask,
  8. LowerTriangularMaskWithTensorBias)
  9. from aphrodite.attention.backends.abstract import (AttentionBackend,
  10. AttentionImpl,
  11. AttentionMetadata,
  12. AttentionMetadataPerStage)
  13. from aphrodite.attention.ops.paged_attn import (PagedAttention,
  14. PagedAttentionMetadata)
  15. class XFormersBackend(AttentionBackend):
  16. @staticmethod
  17. def get_impl_cls() -> Type["XFormersImpl"]:
  18. return XFormersImpl
  19. @staticmethod
  20. def make_metadata(*args, **kwargs) -> "XFormersMetadata":
  21. return XFormersMetadata(*args, **kwargs)
  22. @staticmethod
  23. def get_kv_cache_shape(
  24. num_blocks: int,
  25. block_size: int,
  26. num_kv_heads: int,
  27. head_size: int,
  28. ) -> Tuple[int, ...]:
  29. return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
  30. num_kv_heads, head_size)
  31. @staticmethod
  32. def swap_blocks(
  33. src_kv_cache: torch.Tensor,
  34. dst_kv_cache: torch.Tensor,
  35. src_to_dst: Dict[int, int],
  36. ) -> None:
  37. PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
  38. @staticmethod
  39. def copy_blocks(
  40. kv_caches: List[torch.Tensor],
  41. src_to_dists: Dict[int, List[int]],
  42. ) -> None:
  43. PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
    """Metadata for the XFormers backend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored in a tensor. The tensor has to be
    updated from the `CUDAGraphRunner.forward` API.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    # NOTE(review): with chunked prefill, prefill and decode tokens can be
    # batched together (see the XFormersImpl docstring); this comment may
    # be stale — confirm.
    is_prompt: bool
    # (batch_size,). The prompt length per sequence. None if it is a decoding.
    prompt_lens: Optional[List[int]]
    # prompt_lens stored as a tensor.
    prompt_lens_tensor: Optional[torch.Tensor]

    # NOTE: Definition of context_len, subquery_len, and seqlen.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seqlen ----------------------|
    #                                   |- subquery_len -|

    # WARNING(sang): context_len has a different definition depending on
    # whether it is prefill or decoding. For prefill, it doesn't include the
    # new tokens. For decoding, it includes the new token.

    # Maximum subquery length in the batch.
    max_subquery_len: Optional[int]
    # FIXME: It is for flash attn.
    # Maximum prompt length in the batch.
    max_prompt_len: Optional[int]
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    subquery_start_loc: Optional[torch.Tensor]
    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    def __post_init__(self) -> None:
        # Lazily-built xFormers attention bias, filled in by the first
        # attention op that sees it as None and then reused by later layers.
        # It is a list because, when alibi slopes are used, one bias per
        # prompt is needed (a limitation of the xformers API); otherwise it
        # holds a single mask.
        # Assigned here rather than declared as a field, so it does not
        # appear in the generated __repr__ and __init__.
        self.attn_bias: Optional[List[AttentionBias]] = None
  94. class XFormersImpl(AttentionImpl):
  95. """
  96. If the input tensors contain prompt tokens, the layout is as follows:
  97. |<--------------- num_prefill_tokens ----------------->|
  98. |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
  99. Otherwise, the layout is as follows:
  100. |<----------------- num_decode_tokens ------------------>|
  101. |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
  102. Generation tokens can contain padding when cuda-graph is used.
  103. Currently, prompt tokens don't contain any padding.
  104. The prompts might have different lengths, while the generation tokens
  105. always have length 1.
  106. If chunked prefill is enabled, prefill tokens and decode tokens can be
  107. batched together in a flattened 1D query.
  108. |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
  109. |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
  110. Currently, cuda graph is disabled for chunked prefill, meaning there's no
  111. padding between prefill and decode tokens.
  112. """
  113. def __init__(
  114. self,
  115. num_heads: int,
  116. head_size: int,
  117. scale: float,
  118. num_kv_heads: Optional[int] = None,
  119. alibi_slopes: Optional[List[float]] = None,
  120. sliding_window: Optional[int] = None,
  121. ) -> None:
  122. self.num_heads = num_heads
  123. self.head_size = head_size
  124. self.scale = float(scale)
  125. self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
  126. self.sliding_window = sliding_window
  127. if alibi_slopes is not None:
  128. alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
  129. self.alibi_slopes = alibi_slopes
  130. assert self.num_heads % self.num_kv_heads == 0
  131. self.num_queries_per_kv = self.num_heads // self.num_kv_heads
  132. suppored_head_sizes = PagedAttention.get_supported_head_sizes()
  133. if head_size not in suppored_head_sizes:
  134. raise ValueError(
  135. f"Head size {head_size} is not supported by PagedAttention. "
  136. f"Supported head sizes are: {suppored_head_sizes}.")
  137. def forward(
  138. self,
  139. query: torch.Tensor,
  140. key: torch.Tensor,
  141. value: torch.Tensor,
  142. kv_cache: Optional[torch.Tensor],
  143. attn_metadata: AttentionMetadata[XFormersMetadata],
  144. kv_scale: float,
  145. ) -> torch.Tensor:
  146. """Forward pass with xFormers and PagedAttention.
  147. Args:
  148. query: shape = [num_tokens, num_heads * head_size]
  149. key: shape = [num_tokens, num_kv_heads * head_size]
  150. value: shape = [num_tokens, num_kv_heads * head_size]
  151. kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
  152. attn_metadata: Metadata for attention.
  153. Returns:
  154. shape = [num_tokens, num_heads * head_size]
  155. """
  156. num_tokens, hidden_size = query.shape
  157. query = query.view(-1, self.num_heads, self.head_size)
  158. key = key.view(-1, self.num_kv_heads, self.head_size)
  159. value = value.view(-1, self.num_kv_heads, self.head_size)
  160. if kv_cache is not None:
  161. key_cache, value_cache = PagedAttention.split_kv_cache(
  162. kv_cache, self.num_kv_heads, self.head_size)
  163. # Reshape the input keys and values and store them in the cache.
  164. # If kv_cache is not provided, the new key and value tensors are
  165. # not cached. This happens during the initial memory profiling run.
  166. PagedAttention.write_to_paged_cache(key, value, key_cache,
  167. value_cache,
  168. attn_metadata.slot_mapping,
  169. attn_metadata.kv_cache_dtype,
  170. kv_scale)
  171. num_prefill_tokens = attn_metadata.num_prefill_tokens
  172. num_decode_tokens = attn_metadata.num_decode_tokens
  173. assert key.shape[0] == num_prefill_tokens + num_decode_tokens
  174. assert value.shape[0] == num_prefill_tokens + num_decode_tokens
  175. output = torch.empty_like(query)
  176. # Query for decode. KV is not needed because it is already cached.
  177. decode_query = query[num_prefill_tokens:]
  178. # QKV for prefill.
  179. query = query[:num_prefill_tokens]
  180. key = key[:num_prefill_tokens]
  181. value = value[:num_prefill_tokens]
  182. assert query.shape[0] == num_prefill_tokens
  183. assert decode_query.shape[0] == num_decode_tokens
  184. if prefill_meta := attn_metadata.prefill_metadata:
  185. # Prompt run.
  186. if kv_cache is None or prefill_meta.block_tables.numel() == 0:
  187. # normal attention.
  188. # block tables are empty if the prompt does not have a cached
  189. # prefix.
  190. out = self._run_memory_efficient_xformers_forward(
  191. query, key, value, prefill_meta)
  192. assert out.shape == output[:num_prefill_tokens].shape
  193. output[:num_prefill_tokens] = out
  194. else:
  195. # prefix-enabled attention
  196. # TODO: this triton kernel has regression issue (broke) to
  197. # deal with different data types between KV and FP8 KV cache,
  198. # to be addressed separately.
  199. out = PagedAttention.forward_prefix(
  200. query,
  201. key,
  202. value,
  203. key_cache,
  204. value_cache,
  205. prefill_meta.block_tables,
  206. prefill_meta.subquery_start_loc,
  207. prefill_meta.prompt_lens_tensor,
  208. prefill_meta.context_lens,
  209. prefill_meta.max_subquery_len,
  210. self.alibi_slopes,
  211. )
  212. assert output[:num_prefill_tokens].shape == out.shape
  213. output[:num_prefill_tokens] = out
  214. if decode_meta := attn_metadata.decode_metadata:
  215. output[num_prefill_tokens:] = PagedAttention.forward_decode(
  216. decode_query,
  217. key_cache,
  218. value_cache,
  219. decode_meta.block_tables,
  220. decode_meta.context_lens,
  221. decode_meta.max_context_len,
  222. attn_metadata.kv_cache_dtype,
  223. self.num_kv_heads,
  224. self.scale,
  225. self.alibi_slopes,
  226. kv_scale,
  227. )
  228. # Reshape the output tensor.
  229. return output.view(-1, self.num_heads * self.head_size)
  230. def _run_memory_efficient_xformers_forward(
  231. self,
  232. query: torch.Tensor,
  233. key: torch.Tensor,
  234. value: torch.Tensor,
  235. attn_metadata: XFormersMetadata,
  236. ) -> torch.Tensor:
  237. """Attention for 1D query of multiple prompts. Multiple prompt
  238. tokens are flattened in to `query` input.
  239. See https://facebookresearch.github.io/xformers/components/ops.html
  240. for API spec.
  241. Args:
  242. output: shape = [num_prefill_tokens, num_heads, head_size]
  243. query: shape = [num_prefill_tokens, num_heads, head_size]
  244. key: shape = [num_prefill_tokens, num_kv_heads, head_size]
  245. value: shape = [num_prefill_tokens, num_kv_heads, head_size]
  246. attn_metadata: Metadata for attention.
  247. """
  248. assert attn_metadata.prompt_lens is not None
  249. original_query = query
  250. if self.num_kv_heads != self.num_heads:
  251. # GQA/MQA requires the shape [B, M, G, H, K].
  252. # Note that the output also has the same shape (which is different
  253. # from a spec from the doc).
  254. query = query.view(query.shape[0], self.num_kv_heads,
  255. self.num_queries_per_kv, query.shape[-1])
  256. key = key[:, :,
  257. None, :].expand(key.shape[0], self.num_kv_heads,
  258. self.num_queries_per_kv, key.shape[-1])
  259. value = value[:, :,
  260. None, :].expand(value.shape[0], self.num_kv_heads,
  261. self.num_queries_per_kv,
  262. value.shape[-1])
  263. # Set attention bias if not provided. This typically happens at
  264. # the very attention layer of every iteration.
  265. # FIXME: This is a hack.
  266. if attn_metadata.attn_bias is None:
  267. if self.alibi_slopes is None:
  268. attn_bias = BlockDiagonalCausalMask.from_seqlens(
  269. attn_metadata.prompt_lens)
  270. if self.sliding_window is not None:
  271. attn_bias = attn_bias.make_local_attention(
  272. self.sliding_window)
  273. attn_metadata.attn_bias = [attn_bias]
  274. else:
  275. attn_metadata.attn_bias = _make_alibi_bias(
  276. self.alibi_slopes, self.num_kv_heads, query.dtype,
  277. attn_metadata.prompt_lens)
  278. # No alibi slopes.
  279. # TODO: Too many view operations. Let's try to reduce
  280. # them in the future for code readability.
  281. if self.alibi_slopes is None:
  282. # Add the batch dimension.
  283. query = query.unsqueeze(0)
  284. key = key.unsqueeze(0)
  285. value = value.unsqueeze(0)
  286. out = xops.memory_efficient_attention_forward(
  287. query,
  288. key,
  289. value,
  290. attn_bias=attn_metadata.attn_bias[0],
  291. p=0.0,
  292. scale=self.scale)
  293. return out.view_as(original_query)
  294. # Attention with alibi slopes.
  295. # FIXME: Because xformers does not support dynamic sequence
  296. # lengths with custom attention bias, we process each prompt one by
  297. # one. This is inefficient, especially when we have many short prompts.
  298. output = torch.empty_like(original_query)
  299. start = 0
  300. for i, prompt_len in enumerate(attn_metadata.prompt_lens):
  301. end = start + prompt_len
  302. out = xops.memory_efficient_attention_forward(
  303. query[None, start:end],
  304. key[None, start:end],
  305. value[None, start:end],
  306. attn_bias=attn_metadata.attn_bias[i],
  307. p=0.0,
  308. scale=self.scale)
  309. # TODO: Unnecessary copy. Optimize.
  310. output[start:end].copy_(out.view_as(original_query[start:end]))
  311. start += prompt_len
  312. return output
  313. def _make_alibi_bias(
  314. alibi_slopes: torch.Tensor,
  315. num_kv_heads: int,
  316. dtype: torch.dtype,
  317. prompt_lens: List[int],
  318. ) -> LowerTriangularMaskWithTensorBias:
  319. attn_biases = []
  320. for prompt_len in prompt_lens:
  321. bias = torch.arange(prompt_len, dtype=dtype)
  322. # NOTE: HF uses
  323. # `bias = bias[None, :].repeat(prompt_len, 1)`
  324. # here. We find that both biases give the same results, but
  325. # the bias below more accurately follows the original ALiBi
  326. # paper.
  327. # Calculate a matrix where each element represents ith element- jth
  328. # element.
  329. bias = bias[None, :] - bias[:, None]
  330. padded_len = (prompt_len + 7) // 8 * 8
  331. num_heads = alibi_slopes.shape[0]
  332. bias = torch.empty(
  333. 1, # batch size
  334. num_heads,
  335. prompt_len,
  336. padded_len,
  337. device=alibi_slopes.device,
  338. dtype=dtype,
  339. )[:, :, :, :prompt_len].copy_(bias)
  340. bias.mul_(alibi_slopes[:, None, None])
  341. if num_heads != num_kv_heads:
  342. bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
  343. attn_biases.append(LowerTriangularMaskWithTensorBias(bias))
  344. return attn_biases