xformers.py

  1. """Attention layer with xFormers and PagedAttention."""
  2. from dataclasses import dataclass
  3. from typing import Dict, List, Optional, Tuple, Type
  4. import torch
  5. from xformers import ops as xops
  6. from xformers.ops.fmha.attn_bias import (
  7. AttentionBias,
  8. BlockDiagonalCausalMask,
  9. LowerTriangularMaskWithTensorBias,
  10. )
  11. from aphrodite.attention.backends.abstract import (
  12. AttentionBackend,
  13. AttentionImpl,
  14. AttentionMetadata,
  15. AttentionMetadataPerStage,
  16. )
  17. from aphrodite.attention.ops.paged_attn import (
  18. PagedAttention,
  19. PagedAttentionMetadata,
  20. )


class XFormersBackend(AttentionBackend):

    @staticmethod
    def get_impl_cls() -> Type["XFormersImpl"]:
        return XFormersImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "XFormersMetadata":
        return XFormersMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
    """Metadata for the XFormers backend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored in a tensor. The tensor has to be
    updated from the `CUDAGraphRunner.forward` API.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # (batch_size,). The prompt length per sequence. None if it is a decoding.
    prompt_lens: Optional[List[int]]
    # prompt_lens stored as a tensor.
    prompt_lens_tensor: Optional[torch.Tensor]

    # NOTE: Definition of context_len, subquery_len, and seqlen.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seqlen ----------------------|
    #                                   |- subquery_len -|

    # WARNING(sang): context_len has a different definition depending on
    # whether it is prefill or decoding. For prefill, it does not include the
    # new tokens. For decoding, it includes the new token.
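    # For example, if a sequence already has 4 tokens in the KV cache and 2
    # new tokens are processed in this step, then context_len = 4,
    # subquery_len = 2, and seqlen = 6 for that prefill.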

    # Maximum subquery length in the batch.
    max_subquery_len: Optional[int]
    # FIXME: It is for flash attn.
    # Maximum prompt length in the batch.
    max_prompt_len: Optional[int]
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery lengths
    # are [4, 6], it is [0, 4, 10].
    subquery_start_loc: Optional[torch.Tensor]
    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence lengths
    # are [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list because it needs to be set per prompt when alibi
        # slopes are used, due to a limitation of the xformers API.
        # It will not appear in the __repr__ or __init__ of the dataclass.
        self.attn_bias: Optional[List[AttentionBias]] = None


class XFormersImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.
    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
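        # E.g., with num_heads=32 and num_kv_heads=8 (grouped-query
        # attention), each KV head is shared by 4 query heads.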

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata[XFormersMetadata],
        kv_scale: float,
    ) -> torch.Tensor:
        """Forward pass with xFormers and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        num_tokens, hidden_size = query.shape
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache,
                                                attn_metadata.slot_mapping,
                                                attn_metadata.kv_cache_dtype,
                                                kv_scale)

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # Normal attention.
                # Block tables are empty if the prompt does not have a cached
                # prefix.
                out = self._run_memory_efficient_xformers_forward(
                    query, key, value, prefill_meta)
                assert out.shape == output[:num_prefill_tokens].shape
                output[:num_prefill_tokens] = out
            else:
                # Prefix-enabled attention.
                # TODO: this triton kernel currently has a regression when the
                # KV dtype and the FP8 KV-cache dtype differ; to be addressed
                # separately.
                out = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.subquery_start_loc,
                    prefill_meta.prompt_lens_tensor,
                    prefill_meta.context_lens,
                    prefill_meta.max_subquery_len,
                    self.alibi_slopes,
                )
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
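
        # NOTE: decode metadata is expected to be present only when a KV cache
        # has been allocated (the cache-less profiling run is prefill-only),
        # so key_cache/value_cache from the branch above are defined here.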
        if decode_meta := attn_metadata.decode_metadata:
            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.context_lens,
                decode_meta.max_context_len,
                attn_metadata.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _run_memory_efficient_xformers_forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: XFormersMetadata,
    ) -> torch.Tensor:
        """Attention for 1D query of multiple prompts. Multiple prompt
        tokens are flattened into the `query` input.

        See https://facebookresearch.github.io/xformers/components/ops.html
        for API spec.

        Args:
            query: shape = [num_prefill_tokens, num_heads, head_size]
            key: shape = [num_prefill_tokens, num_kv_heads, head_size]
            value: shape = [num_prefill_tokens, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_prefill_tokens, num_heads, head_size]
        """
        original_query = query
        if self.num_kv_heads != self.num_heads:
            # GQA/MQA requires the shape [B, M, G, H, K].
            # Note that the output also has the same shape (which differs
            # from the spec in the xformers docs).
            query = query.view(query.shape[0], self.num_kv_heads,
                               self.num_queries_per_kv, query.shape[-1])
            key = key[:, :,
                      None, :].expand(key.shape[0], self.num_kv_heads,
                                      self.num_queries_per_kv, key.shape[-1])
            value = value[:, :,
                          None, :].expand(value.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          value.shape[-1])
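            # E.g., for num_heads=8, num_kv_heads=2, head_size=64, and T
            # prefill tokens, query becomes [T, 2, 4, 64] and key/value are
            # expanded (without copying) to the same [T, 2, 4, 64] shape.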

        # Set attention bias if not provided. This typically happens at
        # the very first attention layer of every iteration.
        # FIXME: This is a hack.
        if attn_metadata.attn_bias is None:
            if self.alibi_slopes is None:
                attn_bias = BlockDiagonalCausalMask.from_seqlens(
                    attn_metadata.prompt_lens)
                if self.sliding_window is not None:
                    attn_bias = attn_bias.make_local_attention(
                        self.sliding_window)
                attn_metadata.attn_bias = [attn_bias]
            else:
                attn_metadata.attn_bias = _make_alibi_bias(
                    self.alibi_slopes, self.num_kv_heads, query.dtype,
                    attn_metadata.prompt_lens)

        # No alibi slopes.
        # TODO: Too many view operations. Let's try to reduce
        # them in the future for code readability.
        if self.alibi_slopes is None:
            # Add the batch dimension.
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
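            # A single block-diagonal causal mask covers every prompt in the
            # flattened batch, so one xformers call below handles all of the
            # prefill sequences at once.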
            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
                attn_bias=attn_metadata.attn_bias[0],
                p=0.0,
                scale=self.scale)
            return out.view_as(original_query)

        # Attention with alibi slopes.
        # FIXME: Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short prompts.
        output = torch.empty_like(original_query)
        start = 0
        for i, prompt_len in enumerate(attn_metadata.prompt_lens):
            end = start + prompt_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=attn_metadata.attn_bias[i],
                p=0.0,
                scale=self.scale)
            # TODO: Unnecessary copy. Optimize.
            output[start:end].copy_(out.view_as(original_query[start:end]))
            start += prompt_len
        return output


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    prompt_lens: List[int],
) -> List[AttentionBias]:
    attn_biases = []
    for prompt_len in prompt_lens:
        bias = torch.arange(prompt_len, dtype=dtype)
        # NOTE: HF uses
        #     `bias = bias[None, :].repeat(prompt_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        # Calculate a matrix where each element is the difference between the
        # ith and jth positions.
        bias = bias[None, :] - bias[:, None]
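
        # When using a custom attention bias, xformers requires the bias to be
        # a slice of a tensor whose last dimension is a multiple of 8; hence
        # the padded allocation below.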
        padded_len = (prompt_len + 7) // 8 * 8
        num_heads = alibi_slopes.shape[0]
        bias = torch.empty(
            1,  # batch size
            num_heads,
            prompt_len,
            padded_len,
            device=alibi_slopes.device,
            dtype=dtype,
        )[:, :, :, :prompt_len].copy_(bias)
        bias.mul_(alibi_slopes[:, None, None])
        if num_heads != num_kv_heads:
            bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
        attn_biases.append(LowerTriangularMaskWithTensorBias(bias))

    return attn_biases
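

# A minimal usage sketch for _make_alibi_bias (illustrative values only):
# with 4 attention heads, 4 KV heads, and two prompts of lengths 3 and 5,
#
#     slopes = torch.tensor([0.5, 0.25, 0.125, 0.0625])
#     biases = _make_alibi_bias(slopes, num_kv_heads=4,
#                               dtype=torch.float32, prompt_lens=[3, 5])
#
# returns a list of two LowerTriangularMaskWithTensorBias objects, one per
# prompt, each wrapping a [1, 4, prompt_len, prompt_len] bias view backed by
# a buffer whose last dimension is padded to a multiple of 8.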