xformers.py

  1. """Attention layer with xFormers and PagedAttention."""
  2. from dataclasses import dataclass
  3. from typing import Any, Dict, List, Optional, Tuple, Type
  4. import torch
  5. from xformers import ops as xops
  6. from xformers.ops.fmha.attn_bias import (AttentionBias,
  7. BlockDiagonalCausalMask,
  8. LowerTriangularMaskWithTensorBias)
  9. from aphrodite.attention.backends.abstract import (AttentionBackend,
  10. AttentionImpl,
  11. AttentionMetadata)
  12. from aphrodite.attention.ops.paged_attn import (PagedAttention,
  13. PagedAttentionMetadata)


class XFormersBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "xformers"

    @staticmethod
    def get_impl_cls() -> Type["XFormersImpl"]:
        return XFormersImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "XFormersMetadata":
        return XFormersMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)
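
    # The shape returned here by PagedAttention is
    # (2, num_blocks, block_size * num_kv_heads * head_size): index 0 holds
    # the key cache and index 1 the value cache (see the `kv_cache` argument
    # of `XFormersImpl.forward`).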

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for XFormersBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored in a tensor. The tensor has to be
    updated from the `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decode-only batch.
    seq_lens: Optional[List[int]]

    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int]

    # FIXME: It is for flash attn.
    # Maximum sequence length among the prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among the decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery lengths
    # are [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # FIXME: It is for flash attn.
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence lengths
    # are [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,). A tensor of context lengths (tokens that have been
    # computed so far).
    context_lens_tensor: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    _cached_prefill_metadata: Optional["XFormersMetadata"] = None
    _cached_decode_metadata: Optional["XFormersMetadata"] = None

    def __post_init__(self):
        # Set during the execution of the first attention op.
        # It is a list because it needs to be set per prompt when
        # alibi slopes are used, due to a limitation of the xformers API.
        # It will not appear in __repr__ and __init__.
        self.attn_bias: Optional[List[AttentionBias]] = None
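
    # The two properties below carve prefill-only and decode-only views out of
    # a mixed batch. They rely on the batch layout in which all prefill tokens
    # (and sequences) come before all decode tokens, and they cache the result
    # so repeated accesses within one iteration are cheap.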

    @property
    def prefill_metadata(self) -> Optional["XFormersMetadata"]:
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.block_tables is not None

        self._cached_prefill_metadata = XFormersMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=None,
            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
            block_tables=self.block_tables[:self.num_prefills],
            use_cuda_graph=False,
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["XFormersMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        self._cached_decode_metadata = XFormersMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[self.num_prefills:],
            use_cuda_graph=self.use_cuda_graph,
        )
        return self._cached_decode_metadata


class XFormersImpl(AttentionImpl[XFormersMetadata]):
    """
    If the input tensors contain prompt tokens, the layout is as follows:

    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:

    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        if blocksparse_params is not None:
            raise ValueError(
                "XFormers does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: "XFormersMetadata",
        kv_scale: float = 1.0,
    ) -> torch.Tensor:
        """Forward pass with xFormers and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape = [2, num_blocks,
                block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache,
                                                attn_metadata.slot_mapping,
                                                self.kv_cache_dtype, kv_scale)

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens
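
        # prefill_metadata is None for decode-only batches and decode_metadata
        # is None for prefill-only batches, so each branch below runs only
        # when the batch actually contains the corresponding tokens.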
        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # Normal attention.
                # Block tables are empty if the prompt does not have a cached
                # prefix.
                out = self._run_memory_efficient_xformers_forward(
                    query, key, value, prefill_meta)
                assert out.shape == output[:num_prefill_tokens].shape
                output[:num_prefill_tokens] = out
            else:
                # Prefix-enabled attention.
                # TODO: this triton kernel currently has a regression (it
                # breaks when the KV dtype differs from the FP8 KV cache
                # dtype); to be addressed separately.
                out = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window,
                )
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
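
        # Decode run: each decode sequence contributes exactly one query
        # token; its keys/values are gathered from the paged KV cache via the
        # block tables rather than from the `key`/`value` inputs above.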
        if decode_meta := attn_metadata.decode_metadata:
            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.seq_lens_tensor,
                decode_meta.max_decode_seq_len,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _run_memory_efficient_xformers_forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: XFormersMetadata,
    ) -> torch.Tensor:
        """Attention for a 1D query of multiple prompts. Multiple prompt
        tokens are flattened into the `query` input.

        See https://facebookresearch.github.io/xformers/components/ops.html
        for the API spec.

        Args:
            query: shape = [num_prefill_tokens, num_heads, head_size]
            key: shape = [num_prefill_tokens, num_kv_heads, head_size]
            value: shape = [num_prefill_tokens, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        """
        assert attn_metadata.seq_lens is not None
        original_query = query
        if self.num_kv_heads != self.num_heads:
            # GQA/MQA requires the shape [B, M, G, H, K].
            # Note that the output also has the same shape (which is
            # different from the spec in the doc).
            query = query.view(query.shape[0], self.num_kv_heads,
                               self.num_queries_per_kv, query.shape[-1])
            key = key[:, :,
                      None, :].expand(key.shape[0], self.num_kv_heads,
                                      self.num_queries_per_kv, key.shape[-1])
            value = value[:, :,
                          None, :].expand(value.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          value.shape[-1])
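
            # `expand` returns a broadcasted view, so repeating the keys and
            # values across the query-head group does not copy any data.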

        # Set attention bias if not provided. This typically happens at
        # the very first attention layer of every iteration.
        # FIXME: This is a hack.
        if attn_metadata.attn_bias is None:
            if self.alibi_slopes is None:
                attn_bias = BlockDiagonalCausalMask.from_seqlens(
                    attn_metadata.seq_lens)
                if self.sliding_window is not None:
                    attn_bias = attn_bias.make_local_attention(
                        self.sliding_window)
                attn_metadata.attn_bias = [attn_bias]
            else:
                attn_metadata.attn_bias = _make_alibi_bias(
                    self.alibi_slopes, self.num_kv_heads, query.dtype,
                    attn_metadata.seq_lens)
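
        # The bias list is cached on attn_metadata, so every subsequent
        # attention layer in the same iteration reuses it instead of
        # rebuilding it.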

        # No alibi slopes.
        # TODO: Too many view operations. Let's try to reduce
        # them in the future for code readability.
        if self.alibi_slopes is None:
            # Add the batch dimension.
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
                attn_bias=attn_metadata.attn_bias[0],
                p=0.0,
                scale=self.scale)
            return out.view_as(original_query)

        # Attention with alibi slopes.
        # FIXME: Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short
        # prompts.
        output = torch.empty_like(original_query)
        start = 0
        for i, seq_len in enumerate(attn_metadata.seq_lens):
            end = start + seq_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=attn_metadata.attn_bias[i],
                p=0.0,
                scale=self.scale)
            # TODO: Unnecessary copy. Optimize.
            output[start:end].copy_(out.view_as(original_query[start:end]))
            start += seq_len
        return output


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[AttentionBias]:
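    """Build one ALiBi attention bias per prompt.

    Each entry is a (seq_len, seq_len) relative-position bias scaled by the
    per-head slopes and wrapped in LowerTriangularMaskWithTensorBias, so
    xFormers also applies the causal mask.
    """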
    attn_biases = []
    for seq_len in seq_lens:
        bias = torch.arange(seq_len, dtype=dtype)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        # Calculate a relative-position matrix where entry (i, j) equals
        # j - i.
        bias = bias[None, :] - bias[:, None]

        padded_len = (seq_len + 7) // 8 * 8
        num_heads = alibi_slopes.shape[0]
        bias = torch.empty(
            1,  # batch size
            num_heads,
            seq_len,
            padded_len,
            device=alibi_slopes.device,
            dtype=dtype,
        )[:, :, :, :seq_len].copy_(bias)
        bias.mul_(alibi_slopes[:, None, None])
        if num_heads != num_kv_heads:
            bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
        attn_biases.append(LowerTriangularMaskWithTensorBias(bias))

    return attn_biases
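
# Shape note (illustrative example, numbers assumed): for seq_lens=[5] and 32
# query heads, each bias above is a (1, 32, 5, 5) view into storage whose last
# dimension is padded to a multiple of 8 (padded_len == 8 here), presumably to
# satisfy xFormers' alignment requirements; with 8 KV heads it is additionally
# unflattened to (1, 8, 4, 5, 5) to match the GQA layout used in
# _run_memory_efficient_xformers_forward.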