attention.py

  1. """Multi-head attention."""
  2. from typing import List, Optional
  3. import torch
  4. import torch.nn as nn
  5. from xformers import ops as xops
  6. from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
  7. LowerTriangularMaskWithTensorBias)
  8. from aphrodite._C import ops
  9. from aphrodite._C import cache_ops
  10. from aphrodite.modeling.metadata import InputMetadata
  11. from aphrodite.modeling.layers.triton_kernel.prefix_prefill import (
  12. context_attention_fwd)
  13. from aphrodite.common.utils import is_hip
  14. _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
  15. # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
  16. _PARTITION_SIZE = 512
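# Each V2 partition covers _PARTITION_SIZE tokens of context: `_paged_attention`
# below splits a sequence into ceil(max_context_len / _PARTITION_SIZE)
# partitions and reduces the partial results afterwards.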


class PagedAttention(nn.Module):
    """MHA/MQA/GQA layer with PagedAttention.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Reshape and store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention using either
       xformers or the PagedAttention custom op.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
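        # e.g. num_heads=32, num_kv_heads=8 (GQA) -> each KV head is shared by
        # 32 // 8 = 4 query heads; num_kv_heads == num_heads is plain MHA and
        # num_kv_heads == 1 is MQA.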

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """PagedAttention forward pass.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for the inputs.

        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Reshape the keys and values and store them in the cache.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
        if key_cache is not None and value_cache is not None:
            cache_ops.reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                input_metadata.slot_mapping.flatten(),
                input_metadata.kv_cache_dtype,
            )
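        # NOTE: slot_mapping holds, for every new token, its flat slot index in
        # the paged KV cache; reshape_and_cache scatters the fresh keys/values
        # into the block layout described in the docstring above.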

        if input_metadata.is_prompt:
            # Prompt run.
            if self.num_kv_heads != self.num_heads:
                # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
                # project the key and value tensors to the desired number of
                # heads.
                # TODO: Use MQA/GQA kernels for higher performance.
                query = query.view(query.shape[0], self.num_kv_heads,
                                   self.num_queries_per_kv, query.shape[-1])
                key = key[:, :,
                          None, :].expand(key.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          key.shape[-1])
                value = value[:, :, None, :].expand(value.shape[0],
                                                    self.num_kv_heads,
                                                    self.num_queries_per_kv,
                                                    value.shape[-1])
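                # `expand` only creates broadcasted views, so each KV head is
                # repeated num_queries_per_kv times without copying memory.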

            # normal attention
            if (key_cache is None or value_cache is None
                    or input_metadata.block_tables.numel() == 0):
                # Set attention bias if not provided. This typically happens
                # at the very first attention layer of every iteration.
                # FIXME: This is a hack.
                if input_metadata.attn_bias is None:
                    if self.alibi_slopes is None:
                        attn_bias = BlockDiagonalCausalMask.from_seqlens(
                            [seq_len] * batch_size)
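                        # The block-diagonal causal mask lets each prompt in
                        # the packed batch attend only to its own earlier
                        # tokens.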
                        if self.sliding_window is not None:
                            attn_bias = attn_bias.make_local_attention(
                                self.sliding_window)
                        input_metadata.attn_bias = attn_bias
                    else:
                        input_metadata.attn_bias = _make_alibi_bias(
                            self.alibi_slopes, self.num_kv_heads, batch_size,
                            seq_len, query.dtype)

                # TODO: Too many view operations. Let's try to reduce
                # them in the future for code readability.
                if self.alibi_slopes is None:
                    query = query.unsqueeze(0)
                    key = key.unsqueeze(0)
                    value = value.unsqueeze(0)
                else:
                    query = query.unflatten(0, (batch_size, seq_len))
                    key = key.unflatten(0, (batch_size, seq_len))
                    value = value.unflatten(0, (batch_size, seq_len))
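                # Without ALiBi, all sequences stay packed along one token
                # dimension (batch dim of 1) and are separated by the
                # block-diagonal mask; with ALiBi they are unflattened into
                # [batch_size, seq_len, ...] so the per-sequence bias tensor
                # lines up.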
                out = xops.memory_efficient_attention_forward(
                    query,
                    key,
                    value,
                    attn_bias=input_metadata.attn_bias,
                    p=0.0,
                    scale=self.scale,
                    op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
                    (is_hip()) else None,
                )
                output = out.view_as(query)
            else:
                # prefix-enabled attention
                output = torch.empty_like(query)
                context_attention_fwd(
                    query,
                    key,
                    value,
                    output,
                    key_cache,
                    value_cache,
                    input_metadata.block_tables,  # [BS, max_block_per_request]
                    input_metadata.start_loc,
                    input_metadata.prompt_lens,
                    input_metadata.context_lens,
                    input_metadata.max_seq_len,
                    getattr(self, "alibi_slopes", None),
                )
        else:
            # Decoding run.
            output = _paged_attention(
                query,
                key_cache,
                value_cache,
                input_metadata,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    batch_size: int,
    seq_len: int,
    dtype: torch.dtype,
) -> LowerTriangularMaskWithTensorBias:
    bias = torch.arange(seq_len, dtype=dtype)
    # NOTE: HF uses
    #     `bias = bias[None, :].repeat(prompt_len, 1)`
    # here. We find that both biases give the same results, but
    # the bias below more accurately follows the original ALiBi
    # paper.
    bias = bias[None, :] - bias[:, None]
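    # For seq_len == 4 this yields the relative positions
    #     [[ 0,  1,  2,  3],
    #      [-1,  0,  1,  2],
    #      [-2, -1,  0,  1],
    #      [-3, -2, -1,  0]]
    # which are scaled per head by `alibi_slopes` below; the lower-triangular
    # mask later removes the (future) upper triangle.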

    # When using custom attention bias, xformers requires the bias to
    # be sliced from a tensor whose length is a multiple of 8.
    padded_len = (seq_len + 7) // 8 * 8
    num_heads = alibi_slopes.shape[0]
    bias = torch.empty(
        batch_size,
        num_heads,
        seq_len,
        padded_len,
        device=alibi_slopes.device,
        dtype=dtype,
    )[:, :, :, :seq_len].copy_(bias)
    bias.mul_(alibi_slopes[:, None, None])
    if num_heads != num_kv_heads:
        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
    attn_bias = LowerTriangularMaskWithTensorBias(bias)
    return attn_bias


def _paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    input_metadata: InputMetadata,
    num_kv_heads: int,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> torch.Tensor:
    output = torch.empty_like(query)

    block_size = value_cache.shape[3]
    num_seqs, num_heads, head_size = query.shape
    max_num_partitions = (
        (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
        _PARTITION_SIZE)
    # NOTE: We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
    # TODO: Tune this heuristic.
    # For context len > 8192, use V2 kernel to avoid shared memory shortage.
    use_v1 = input_metadata.max_context_len <= 8192 and (
        max_num_partitions == 1 or num_seqs * num_heads > 512)
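    # Example with _PARTITION_SIZE = 512: max_context_len = 1024 gives
    # max_num_partitions = 2, so V2 runs unless num_seqs * num_heads > 512;
    # max_context_len <= 512 (a single partition) always picks V1.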
    if use_v1:
        # Run PagedAttention V1.
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
            input_metadata.kv_cache_dtype,
        )
    else:
        # Run PagedAttention V2.
        assert _PARTITION_SIZE % block_size == 0
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
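        # V2 runs in two phases: each partition writes a partial attention
        # result into tmp_output together with its softmax statistics
        # (exp_sums, max_logits), and the kernel then reduces the partitions
        # into `output`.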
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
            input_metadata.kv_cache_dtype,
        )
    return output
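

# Minimal usage sketch (illustrative only; the argument values are assumptions
# and a real InputMetadata is built by the model runner, not by hand):
#
#     attn = PagedAttention(num_heads=32, head_size=128, scale=128**-0.5,
#                           num_kv_heads=8)
#     # query: [batch_size, seq_len, num_heads * head_size]
#     # key/value: [batch_size, seq_len, num_kv_heads * head_size]
#     out = attn(query, key, value, key_cache, value_cache, input_metadata)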