  1. """
  2. Multi-head Paged Attention by Woosuk et al. (vLLM) Copyright (c) 2023.
  3. https://vllm.ai/
  4. """
  5. from typing import Any, Dict, List, Optional
  6. import torch
  7. import torch.nn as nn
  8. from xformers import ops as xops
  9. from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
  10. LowerTriangularMaskWithTensorBias)
  11. from aphrodite import attention_ops
  12. from aphrodite import cache_ops
  13. from aphrodite.modeling.metadata import InputMetadata
  14. from aphrodite.modeling.layers.rotary_embedding import (
  15. DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
  16. RotaryEmbedding, YaRNScalingRotaryEmbedding)
  17. _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
  18. # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
  19. _PARTITION_SIZE = 512


class PagedAttention(nn.Module):
    # pylint: disable=line-too-long
    """GPT-style multi-head PagedAttention.

    This class takes query, key, and value tensors as input. The input
    tensors can either contain prompt tokens or generation tokens, in
    addition to paddings.

    The class does the following:
    1. Perform multi_query_kv_attention for the prompts. This operation does
       not use the KV cache.
    2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
       operations are issued by the cache engine before executing the forward
       pass of the model, and they are executed asynchronously.
    3. Reshape and store the input key and value tensors in the KV cache.
    4. Perform single_query_cached_kv_attention for the generation tokens.
       This operation reads the previous key and value tensors from the KV
       cache.
    5. Return the output tensor.
    """

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 num_kv_heads: Optional[int] = None,
                 sliding_window: Optional[int] = None) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.head_mapping = torch.repeat_interleave(
            torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
            self.num_queries_per_kv)
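        # Worked example (illustrative): with num_heads=8 and num_kv_heads=2,
        # num_queries_per_kv is 4 and head_mapping is [0, 0, 0, 0, 1, 1, 1, 1],
        # i.e. the paged-attention kernels route every group of 4 query heads
        # to the same cached KV head (grouped-query attention).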
        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

    def set_attn_bias(
        self,
        input_metadata: InputMetadata,
        dtype: torch.dtype,
    ) -> None:
        del dtype  # Unused.
        if input_metadata.attn_bias is not None:
            # Already set by a previous layer.
            return
        prompt_lens = [input_metadata.max_prompt_len
                       ] * input_metadata.num_prompts
        attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
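        # Sketch of the mask semantics (hedged; see the xformers docs): for
        # seqlens such as [4, 2], from_seqlens() builds a block-diagonal
        # causal mask, so each prompt's tokens attend only to earlier tokens
        # of the same prompt and never to the other prompt. Here every prompt
        # is padded to max_prompt_len, so all blocks share the same length.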
        if self.sliding_window is not None:
            attn_bias = attn_bias.make_local_attention(self.sliding_window)
        input_metadata.attn_bias = attn_bias

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Normal attention for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
            input_metadata: metadata for paged attention.
        """
        if self.num_kv_heads != self.num_heads:
            # Project the key and value tensors to the desired number of heads.
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=1)
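            # Shape sketch (illustrative numbers): with num_kv_heads=2 and
            # num_queries_per_kv=4, a key of shape
            # [num_prompt_tokens, 2, head_size] becomes
            # [num_prompt_tokens, 8, head_size], so every query head has a
            # matching key/value head for the dense prompt attention below.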

        # TODO: The unsqueeze op may incur some CPU overhead. Optimize.
        out = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=input_metadata.attn_bias,
            p=0.0,
            scale=self.scale,
        )
        # TODO: Unnecessary copy. Optimize.
        output.copy_(out.squeeze(0))
        return output

    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
        """Returns the slopes for the alibi attention bias.

        Returns:
            slopes: shape = [num_heads]
        """
        return None

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        alibi_slopes: Optional[torch.Tensor],
    ) -> None:
        """PagedAttention for the generation tokens.

        Args:
            output: shape = [num_generation_tokens, num_heads, head_size]
            query: shape = [num_generation_tokens, num_heads, head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            alibi_slopes: shape = [num_heads]
        """
        block_size = value_cache.shape[3]
        num_seqs, num_heads, head_size = query.shape
        max_num_partitions = (
            (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
            _PARTITION_SIZE)
        # NOTE: We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO: Tune this heuristic.
        use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
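        # For example (illustrative numbers): with max_context_len = 1300 and
        # _PARTITION_SIZE = 512, max_num_partitions is 3, so the V2 kernel is
        # used unless num_seqs * num_heads > 512 already provides enough
        # parallelism, in which case the single-pass V1 kernel is chosen.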
        if use_v1:
            # Run PagedAttention V1.
            attention_ops.paged_attention_v1(
                output,
                query,
                key_cache,
                value_cache,
                self.head_mapping,
                self.scale,
                input_metadata.block_tables,
                input_metadata.context_lens,
                block_size,
                input_metadata.max_context_len,
                alibi_slopes,
            )
        else:
            # Run PagedAttention V2.
            assert _PARTITION_SIZE % block_size == 0
            # Workspace for the two-pass V2 kernel: partial outputs plus the
            # per-partition softmax statistics (exp sums and max logits) that
            # the reduction pass combines into `output`.
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=output.dtype,
                device=output.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=output.device,
            )
            max_logits = torch.empty_like(exp_sums)
            attention_ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                self.head_mapping,
                self.scale,
                input_metadata.block_tables,
                input_metadata.context_lens,
                block_size,
                input_metadata.max_context_len,
                alibi_slopes,
            )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """PagedAttention forward pass.

        NOTE: The query, key, and value tensors must be sliced from a qkv
        tensor of shape [batch_size, seq_len, 3 * num_heads * head_size].

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        batch_size, seq_len, _ = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            # Prompt run.
            assert input_metadata.num_generation_tokens == 0
            self.set_attn_bias(input_metadata, dtype=query.dtype)
            self.multi_query_kv_attention(
                output,
                query,
                key,
                value,
                input_metadata,
            )

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        # When key_cache and value_cache are not provided, the new key
        # and value vectors will not be cached.
        if key_cache is not None and value_cache is not None:
            key_to_cache = key
            value_to_cache = value
            slot_mapping = input_metadata.slot_mapping.view(-1)
            if input_metadata.to_cache is not None:
                key_to_cache = key_to_cache[input_metadata.to_cache]
                value_to_cache = value_to_cache[input_metadata.to_cache]
                slot_mapping = slot_mapping[input_metadata.to_cache]
            cache_ops.reshape_and_cache(
                key_to_cache,
                value_to_cache,
                key_cache,
                value_cache,
                slot_mapping,
            )

        if input_metadata.num_generation_tokens > 0:
            # Decoding run.
            assert input_metadata.num_prompt_tokens == 0
            assert key_cache is not None and value_cache is not None, (
                "key_cache and value_cache must be provided when "
                "generating tokens.")
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(output, query, key_cache,
                                                  value_cache, input_metadata,
                                                  self.get_alibi_slopes())

        # Reshape the output tensor.
        # NOTE: The output tensor may include paddings.
        return output.view(batch_size, seq_len,
                           self.num_heads * self.head_size)


class PagedAttentionWithRoPE(PagedAttention):
    """PagedAttention with rotary positional embedding."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        rotary_dim: int,
        max_position: int = 8192,
        base: int = 10000,
        num_kv_heads: Optional[int] = None,
        is_neox_style: bool = True,
        rope_scaling: Optional[Dict[str, Any]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__(num_heads,
                         head_size,
                         scale,
                         num_kv_heads,
                         sliding_window=sliding_window)
        if rope_scaling is None:
            self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
                                              max_position, base,
                                              is_neox_style)
        else:
            scaling_type = rope_scaling["type"]
            scaling_factor = rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LinearScalingRotaryEmbedding(
                    head_size, rotary_dim, max_position, base, is_neox_style,
                    scaling_factor)
            elif scaling_type == "dynamic":
                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
                    head_size, rotary_dim, max_position, base, is_neox_style,
                    scaling_factor)
            elif scaling_type == "yarn":
                original_max_position = rope_scaling[
                    "original_max_position_embeddings"]
                assert max_position == original_max_position * scaling_factor
                extra_kwargs = {
                    k: v
                    for k, v in rope_scaling.items()
                    if k in ("extrapolation_factor", "attn_factor",
                             "beta_fast", "beta_slow")
                }
                self.rotary_emb = YaRNScalingRotaryEmbedding(
                    head_size, rotary_dim, original_max_position, base,
                    is_neox_style, scaling_factor, **extra_kwargs)
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """PagedAttention forward pass with rotary embedding.

        Args:
            positions: shape = [batch_size, seq_len]
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        # Apply rotary embedding to the query and key before passing them
        # to the attention op, so the keys written to the KV cache are
        # already rotated.
        query, key = self.rotary_emb(positions, query, key)
        return super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )


class PagedAttentionWithALiBi(PagedAttention):
    """PagedAttention with ALiBi attention bias."""

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 slopes: List[float],
                 num_kv_heads: Optional[int] = None) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads)
        assert len(slopes) == num_heads
        slopes = torch.tensor(slopes, dtype=torch.float32)
        self.register_buffer("alibi_slopes", slopes, persistent=False)
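        # The slopes are computed by the calling model, one per head. As an
        # illustration (standard ALiBi, not enforced here): for num_heads=8
        # the geometric sequence is [1/2, 1/4, 1/8, ..., 1/256].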

    def set_attn_bias(self, input_metadata: InputMetadata,
                      dtype: torch.dtype) -> None:
        if input_metadata.attn_bias is not None:
            # Already set by a previous layer.
            return
        # Generates ALiBi mask based on the max prompt length.
        max_prompt_len = input_metadata.max_prompt_len
        bias = torch.arange(max_prompt_len, dtype=dtype)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(prompt_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        bias = bias[None, :] - bias[:, None]
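        # Worked example (illustrative): for max_prompt_len = 3 this relative-
        # position term is [[0, 1, 2], [-1, 0, 1], [-2, -1, 0]]. Each head's
        # slope scales it below, and the lower-triangular mask applied by
        # LowerTriangularMaskWithTensorBias removes the future (j > i) entries.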
        bias = bias.to(self.alibi_slopes.device)
        # When using custom attention bias, xformers requires the bias to
        # be sliced from a tensor whose length is a multiple of 8.
        padded_len = (max_prompt_len + 7) // 8 * 8
        bias = torch.empty(
            input_metadata.num_prompts,
            self.num_heads,
            max_prompt_len,
            padded_len,
            device=self.alibi_slopes.device,
            dtype=dtype,
        )[:, :, :, :max_prompt_len].copy_(bias)
        bias.mul_(self.alibi_slopes[:, None, None])
        attn_bias = LowerTriangularMaskWithTensorBias(bias)
        input_metadata.attn_bias = attn_bias

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Attention with ALiBi bias for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
            input_metadata: metadata for paged attention.
        """
        if self.num_kv_heads != self.num_heads:
            # Project the key and value tensors to the desired number of heads.
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=1)
        batch_size = input_metadata.num_prompts
        seq_len = input_metadata.max_prompt_len
        out = xops.memory_efficient_attention_forward(
            query.view(batch_size, seq_len, self.num_heads, self.head_size),
            key.view(batch_size, seq_len, self.num_heads, self.head_size),
            value.view(batch_size, seq_len, self.num_heads, self.head_size),
            attn_bias=input_metadata.attn_bias,
            p=0.0,
            scale=self.scale,
        )
        # TODO: Unnecessary copy. Optimize.
        output.copy_(out.view(-1, self.num_heads, self.head_size))
        return output

    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
        return self.alibi_slopes