  1. """Multi-head attention."""
  2. from typing import List, Optional
  3. import importlib
  4. import torch
  5. import torch.nn as nn
  6. from xformers import ops as xops
  7. from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
  8. LowerTriangularMaskWithTensorBias)
  9. from aphrodite._C import ops
  10. from aphrodite._C import cache_ops
  11. from aphrodite.modeling.metadata import InputMetadata
  12. from aphrodite.modeling.layers.triton_kernel.prefix_prefill import (
  13. context_attention_fwd)
  14. from aphrodite.common.utils import is_hip
  15. _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
  16. # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
  17. _PARTITION_SIZE = 512


class PagedAttention(nn.Module):
    """MHA/MQA/GQA layer with PagedAttention.

    This class takes query, key, and value tensors as input. The input
    tensors can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Reshape and store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention using either
       xformers or the PagedAttention custom op.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

        self.use_ref_attention = self.check_use_ref_attention()

    def check_use_ref_attention(self) -> bool:
        if not is_hip():
            return False
        # For ROCm, check whether flash attention is installed or not.
        # If not, use_ref_attention needs to be True.
        return importlib.util.find_spec("flash_attn") is None

    def ref_masked_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        seq_len, _, _ = query.shape
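        # Causal mask: entries above the diagonal (future positions) are set
        # to the most negative representable value, so softmax assigns them
        # effectively zero probability.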
        attn_mask = torch.triu(torch.ones(seq_len,
                                          seq_len,
                                          dtype=query.dtype,
                                          device=query.device),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(query.dtype).min

        attn_weights = self.scale * torch.einsum("qhd,khd->hqk", query,
                                                 key).float()
        attn_weights = attn_weights + attn_mask.float()
        attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
        out = torch.einsum("hqk,khd->qhd", attn_weights, value)
        return out

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
        kv_quant_param: Optional[List[float]] = None,
    ) -> torch.Tensor:
  90. """PagedAttention forward pass.
  91. Args:
  92. query: shape = [batch_size, seq_len, num_heads * head_size]
  93. key: shape = [batch_size, seq_len, num_kv_heads * head_size]
  94. value: shape = [batch_size, seq_len, num_kv_heads * head_size]
  95. key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
  96. block_size, x]
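                `x` is the vectorization width of the key cache: the number
                of key elements packed together for a vectorized load
                (typically 16 bytes divided by the element size of the cache
                dtype).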
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for the inputs.

        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        batch_size, seq_len, hidden_size = query.shape

        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # FIXME: Remove this when all models support int8 kv cache
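        # Identity quantization parameters (presumably scale/zero-point pairs
        # for the key and value caches); 1.0 / 0.0 leaves cache values
        # unchanged when no per-model parameters are supplied.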
        if kv_quant_param is None:
            kv_quant_param = [1.0, 0.0, 1.0, 0.0]

        # Reshape the keys and values and store them in the cache.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
        if key_cache is not None and value_cache is not None:
            cache_ops.reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                input_metadata.slot_mapping.flatten(),
                input_metadata.kv_cache_dtype,
                *kv_quant_param,
            )

        if input_metadata.is_prompt:
            # Prompt run.
            if self.num_kv_heads != self.num_heads:
                # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
                # project the key and value tensors to the desired number of
                # heads.
                # TODO: Use MQA/GQA kernels for higher performance.
                query = query.view(query.shape[0], self.num_kv_heads,
                                   self.num_queries_per_kv, query.shape[-1])
                key = key[:, :,
                          None, :].expand(key.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          key.shape[-1])
                value = value[:, :, None, :].expand(value.shape[0],
                                                    self.num_kv_heads,
                                                    self.num_queries_per_kv,
                                                    value.shape[-1])
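                # For example, with num_heads=32 and num_kv_heads=8, query
                # becomes [num_tokens, 8, 4, head_size] and key/value are
                # broadcast to the same shape, so each KV head is shared by
                # its 4 query heads.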
            # normal attention
            if (key_cache is None or value_cache is None
                    or input_metadata.block_tables.numel() == 0):
                # Set attention bias if not provided. This typically happens
                # at the very first attention layer of every iteration.
                # FIXME: This is a hack.
                if input_metadata.attn_bias is None:
                    if self.alibi_slopes is None:
                        attn_bias = BlockDiagonalCausalMask.from_seqlens(
                            [seq_len] * batch_size)
                        if self.sliding_window is not None:
                            attn_bias = attn_bias.make_local_attention(
                                self.sliding_window)
                        input_metadata.attn_bias = attn_bias
                    else:
                        input_metadata.attn_bias = _make_alibi_bias(
                            self.alibi_slopes, self.num_kv_heads, batch_size,
                            seq_len, query.dtype)

                if self.use_ref_attention:
                    output = self.ref_masked_attention(
                        query,
                        key,
                        value,
                    )
                    return output.reshape(batch_size, seq_len, hidden_size)

                # TODO: Too many view operations. Let's try to reduce
                # them in the future for code readability.
                if self.alibi_slopes is None:
                    query = query.unsqueeze(0)
                    key = key.unsqueeze(0)
                    value = value.unsqueeze(0)
                else:
                    query = query.unflatten(0, (batch_size, seq_len))
                    key = key.unflatten(0, (batch_size, seq_len))
                    value = value.unflatten(0, (batch_size, seq_len))
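
                # On ROCm, explicitly request the flash-attention forward op;
                # on other platforms, let xformers dispatch to the best
                # available kernel (op=None).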
                out = xops.memory_efficient_attention_forward(
                    query,
                    key,
                    value,
                    attn_bias=input_metadata.attn_bias,
                    p=0.0,
                    scale=self.scale,
                    op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0]
                    if is_hip() else None,
                )
                output = out.view_as(query)
            else:
                # prefix-enabled attention
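                # Part of the prompt is already present in the KV cache
                # (prefix caching), so the Triton kernel attends over both
                # the newly computed keys/values and the cached context
                # blocks referenced by block_tables.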
                output = torch.empty_like(query)
                context_attention_fwd(
                    query,
                    key,
                    value,
                    output,
                    key_cache,
                    value_cache,
                    input_metadata.block_tables,  # [BS, max_block_per_request]
                    input_metadata.start_loc,
                    input_metadata.prompt_lens,
                    input_metadata.context_lens,
                    input_metadata.max_seq_len,
                    getattr(self, "alibi_slopes", None),
                )
        else:
            # Decoding run.
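            # Each sequence contributes a single new query token; keys and
            # values for the context are read from the paged KV cache.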
            output = _paged_attention(
                query,
                key_cache,
                value_cache,
                input_metadata,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_quant_param,
            )

        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    batch_size: int,
    seq_len: int,
    dtype: torch.dtype,
) -> LowerTriangularMaskWithTensorBias:
    bias = torch.arange(seq_len, dtype=dtype)
    # NOTE: HF uses
    #     `bias = bias[None, :].repeat(prompt_len, 1)`
    # here. We find that both biases give the same results, but
    # the bias below more accurately follows the original ALiBi
    # paper.
    bias = bias[None, :] - bias[:, None]
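    # bias[i, j] == j - i, i.e. the signed distance from the query position
    # to the key position. For seq_len = 4, row 0 is [0, 1, 2, 3] and row 3
    # is [-3, -2, -1, 0]; after scaling by the positive per-head slopes
    # below, visible positions (j <= i) receive a penalty proportional to
    # their distance i - j.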

    # When using custom attention bias, xformers requires the bias to
    # be sliced from a tensor whose length is a multiple of 8.
    padded_len = (seq_len + 7) // 8 * 8
    num_heads = alibi_slopes.shape[0]
    bias = torch.empty(
        batch_size,
        num_heads,
        seq_len,
        padded_len,
        device=alibi_slopes.device,
        dtype=dtype,
    )[:, :, :, :seq_len].copy_(bias)
    bias.mul_(alibi_slopes[:, None, None])
    if num_heads != num_kv_heads:
        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
    attn_bias = LowerTriangularMaskWithTensorBias(bias)
    return attn_bias


def _paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    input_metadata: InputMetadata,
    num_kv_heads: int,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
    kv_quant_param: List[float],
) -> torch.Tensor:
    output = torch.empty_like(query)

    block_size = value_cache.shape[3]
    num_seqs, num_heads, head_size = query.shape
    max_num_partitions = (
        (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
        _PARTITION_SIZE)
    # NOTE: We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
    # TODO: Tune this heuristic.
    # For context len > 8192, use V2 kernel to avoid shared memory shortage.
    use_v1 = input_metadata.max_context_len <= 8192 and (
        max_num_partitions == 1 or num_seqs * num_heads > 512)
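    # Example: max_context_len=1024 gives max_num_partitions=2; with 4
    # sequences and 32 heads (128 <= 512 concurrent work items) there is not
    # enough parallelism for V1, so the partitioned V2 kernel is used.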
    if use_v1:
        # Run PagedAttention V1.
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
            input_metadata.kv_cache_dtype,
            *kv_quant_param,
        )
    else:
        # Run PagedAttention V2.
        assert _PARTITION_SIZE % block_size == 0
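        # V2 splits each sequence's context into partitions of
        # _PARTITION_SIZE tokens. Every partition writes a partial result
        # (tmp_output) together with its softmax statistics (exp_sums,
        # max_logits), which are then reduced into `output`.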
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
            input_metadata.kv_cache_dtype,
            *kv_quant_param,
        )
    return output