paged_attn.py

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from aphrodite._C import cache_ops
from aphrodite._C import ops
from aphrodite.attention.ops.prefix_prefill import context_attention_fwd

# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512


@dataclass
class PagedAttentionMetadata:
    """Metadata for PagedAttention."""
    # (batch_size,). The length of the context (tokens already stored in the
    # KV cache) per sequence. WARNING: for a prefill request this does not
    # include the new tokens; for decoding it includes the new token.
    context_lens: Optional[torch.Tensor]
    # Maximum context length in the batch.
    max_context_len: Optional[int]
    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence (seq id -> list of physical blocks).
    # E.g., [0, 1, 2] means the tokens are stored in the 0th, 1st, and 2nd
    # blocks in the KV cache. Each block can hold up to block_size tokens.
    # The 2nd dimension is padded up to max_blocks_per_seq when the model is
    # CUDA-graph captured.
    block_tables: Optional[torch.Tensor]


class PagedAttention:

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [64, 80, 96, 112, 128, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (2, num_blocks, block_size * num_kv_heads * head_size)
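
    # The leading dimension of size 2 holds the key cache at index 0 and the
    # value cache at index 1, as consumed by split_kv_cache() below; each
    # block is stored as a flat block_size * num_kv_heads * head_size entry.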

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # x is the number of elements that fit in 16 bytes; the key cache
        # packs x consecutive elements along its last dimension.
        x = 16 // kv_cache.element_size()
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
                                   -1, x)
        value_cache = kv_cache[1]
        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
        return key_cache, value_cache
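
    # Illustrative example (arbitrary sizes, assuming an fp16 cache) with
    # block_size=16, num_kv_heads=8, head_size=128:
    #   get_kv_cache_shape(...) -> (2, num_blocks, 16 * 8 * 128)
    #                            = (2, num_blocks, 16384)
    #   x = 16 // 2 = 8, so split_kv_cache() views the caches as
    #   key_cache:   (num_blocks, 8, 128 // 8, 16, 8) = (num_blocks, 8, 16, 16, 8)
    #   value_cache: (num_blocks, 8, 128, 16)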

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        kv_scale: float,
    ) -> None:
        cache_ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping.flatten(),
            kv_cache_dtype,
            kv_scale,
        )
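
    # Assumption (not stated in this file): slot_mapping holds one flat slot
    # index per token, i.e. block_number * block_size + offset_within_block,
    # which is why it is flattened before being handed to the cache kernel.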

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        kv_scale: float,
    ) -> torch.Tensor:
        output = torch.empty_like(query)

        block_size = value_cache.shape[3]
        num_seqs, num_heads, head_size = query.shape
        max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) //
                              _PARTITION_SIZE)
        # NOTE: We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO: Tune this heuristic.
        # For context lengths > 8192, use the V2 kernel to avoid running out
        # of shared memory.
        use_v1 = (max_context_len <= 8192
                  and (max_num_partitions == 1 or num_seqs * num_heads > 512))
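        # Illustrative example (values not from the original source): with
        # _PARTITION_SIZE = 512, a max_context_len of 1024 gives
        # max_num_partitions = 2, so V1 is chosen only when
        # num_seqs * num_heads > 512; a batch of 8 sequences with 32 heads
        # (256 total) would therefore fall through to V2.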
        if use_v1:
            # Run PagedAttention V1.
            ops.paged_attention_v1(
                output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                context_lens,
                block_size,
                max_context_len,
                alibi_slopes,
                kv_cache_dtype,
                kv_scale,
            )
        else:
            # Run PagedAttention V2.
            assert _PARTITION_SIZE % block_size == 0
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=output.dtype,
                device=output.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=output.device,
            )
            max_logits = torch.empty_like(exp_sums)
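            # Assumption (not stated here): tmp_output, exp_sums, and
            # max_logits hold per-partition partial outputs and softmax
            # statistics; the V2 kernel computes attention per partition and
            # then reduces these partials into `output`.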
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                context_lens,
                block_size,
                max_context_len,
                alibi_slopes,
                kv_cache_dtype,
                kv_scale,
            )
        return output

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        subquery_start_loc: torch.Tensor,
        prompt_lens_tensor: torch.Tensor,
        context_lens: torch.Tensor,
        max_subquery_len: int,
        alibi_slopes: Optional[torch.Tensor],
    ) -> torch.Tensor:
        output = torch.empty_like(query)
        context_attention_fwd(
            query,
            key,
            value,
            output,
            key_cache,
            value_cache,
            block_tables,
            # subquery_start_loc is (batch_size + 1,)
            subquery_start_loc[:-1],
            prompt_lens_tensor,
            context_lens,
            max_subquery_len,
            alibi_slopes,
        )
        return output

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        src_key_cache = src_kv_cache[0]
        dst_key_cache = dst_kv_cache[0]
        cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)

        src_value_cache = src_kv_cache[1]
        dst_value_cache = dst_kv_cache[1]
        cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
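

# Minimal usage sketch (illustrative, not part of the original module). It
# only exercises the pure-PyTorch helpers above, so it runs on CPU, but it
# still assumes an aphrodite install whose _C extension imports. The sizes
# below are arbitrary example values.
if __name__ == "__main__":
    num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 128
    shape = PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                              num_kv_heads, head_size)
    kv_cache = torch.zeros(shape, dtype=torch.float16)
    key_cache, value_cache = PagedAttention.split_kv_cache(
        kv_cache, num_kv_heads, head_size)
    # Expected for fp16: key (4, 8, 16, 16, 8) and value (4, 8, 128, 16).
    print(key_cache.shape, value_cache.shape)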