1
0

paged_attn.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. from dataclasses import dataclass
  2. from typing import List, Optional, Tuple
  3. import torch
  4. from aphrodite import _custom_ops as ops
  5. from aphrodite.triton_utils import HAS_TRITON
  6. if HAS_TRITON:
  7. from aphrodite.attention.ops.prefix_prefill import context_attention_fwd
  8. # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
  9. _PARTITION_SIZE = 512
  10. @dataclass
  11. class PagedAttentionMetadata:
  12. """Metadata for PagedAttention."""
  13. # (batch_size,). The length of sequences (entire tokens seen so far) per
  14. # sequence.
  15. seq_lens_tensor: Optional[torch.Tensor]
  16. # Maximum sequence length in the batch. 0 if it is prefill-only batch.
  17. max_decode_seq_len: int
  18. # (batch_size, max_blocks_per_seq).
  19. # Block addresses per sequence. (Seq id -> list of physical block)
  20. # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
  21. # in the kv cache. Each block can contain up to block_size tokens.
  22. # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
  23. # captured.
  24. block_tables: Optional[torch.Tensor]
  25. class PagedAttention:
  26. @staticmethod
  27. def get_supported_head_sizes() -> List[int]:
  28. return [64, 80, 96, 112, 120, 128, 192, 256]
  29. @staticmethod
  30. def get_kv_cache_shape(
  31. num_blocks: int,
  32. block_size: int,
  33. num_kv_heads: int,
  34. head_size: int,
  35. ) -> Tuple[int, ...]:
  36. return (2, num_blocks, block_size * num_kv_heads * head_size)
  37. @staticmethod
  38. def split_kv_cache(
  39. kv_cache: torch.Tensor,
  40. num_kv_heads: int,
  41. head_size: int,
  42. ) -> Tuple[torch.Tensor, torch.Tensor]:
  43. x = 16 // kv_cache.element_size()
  44. num_blocks = kv_cache.shape[1]
  45. key_cache = kv_cache[0]
  46. key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
  47. -1, x)
  48. value_cache = kv_cache[1]
  49. value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
  50. return key_cache, value_cache
  51. @staticmethod
  52. def write_to_paged_cache(
  53. key: torch.Tensor,
  54. value: torch.Tensor,
  55. key_cache: torch.Tensor,
  56. value_cache: torch.Tensor,
  57. slot_mapping: torch.Tensor,
  58. kv_cache_dtype: str,
  59. k_scale: float,
  60. v_scale: float,
  61. ) -> None:
  62. ops.reshape_and_cache(
  63. key,
  64. value,
  65. key_cache,
  66. value_cache,
  67. slot_mapping.flatten(),
  68. kv_cache_dtype,
  69. k_scale,
  70. v_scale,
  71. )
  72. @staticmethod
  73. def forward_decode(
  74. query: torch.Tensor,
  75. key_cache: torch.Tensor,
  76. value_cache: torch.Tensor,
  77. block_tables: torch.Tensor,
  78. seq_lens: torch.Tensor,
  79. max_seq_len: int,
  80. kv_cache_dtype: str,
  81. num_kv_heads: int,
  82. scale: float,
  83. alibi_slopes: Optional[torch.Tensor],
  84. k_scale: float,
  85. v_scale: float,
  86. tp_rank: int = 0,
  87. blocksparse_local_blocks: int = 0,
  88. blocksparse_vert_stride: int = 0,
  89. blocksparse_block_size: int = 64,
  90. blocksparse_head_sliding_step: int = 0,
  91. ) -> torch.Tensor:
  92. if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
  93. # use blocksparse paged attention
  94. block_size = value_cache.size(-1)
  95. assert (blocksparse_block_size > 0 and
  96. blocksparse_block_size % block_size == 0), \
  97. (f"{blocksparse_block_size=} needs to be a multiple of"
  98. f"{block_size=} used in block_tables.")
  99. output = torch.empty_like(query)
  100. block_size = value_cache.shape[3]
  101. num_seqs, num_heads, head_size = query.shape
  102. max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
  103. _PARTITION_SIZE)
  104. # NOTE: We use a simple heuristic to decide whether to use
  105. # PagedAttention V1 or V2. If the number of partitions is 1, we use
  106. # V1 to avoid the overhead of reduction. Also, if the number of
  107. # sequences or heads is large, we use V1 since there is enough work
  108. # to parallelize.
  109. # TODO: Tune this heuristic.
  110. # For context len > 8192, use V2 kernel to avoid shared memory shortage.
  111. use_v1 = (max_seq_len <= 8192
  112. and (max_num_partitions == 1 or num_seqs * num_heads > 512))
  113. if use_v1:
  114. # Run PagedAttention V1.
  115. ops.paged_attention_v1(
  116. output,
  117. query,
  118. key_cache,
  119. value_cache,
  120. num_kv_heads,
  121. scale,
  122. block_tables,
  123. seq_lens,
  124. block_size,
  125. max_seq_len,
  126. alibi_slopes,
  127. kv_cache_dtype,
  128. k_scale,
  129. v_scale,
  130. tp_rank,
  131. blocksparse_local_blocks,
  132. blocksparse_vert_stride,
  133. blocksparse_block_size,
  134. blocksparse_head_sliding_step,
  135. )
  136. else:
  137. # Run PagedAttention V2.
  138. assert _PARTITION_SIZE % block_size == 0
  139. tmp_output = torch.empty(
  140. size=(num_seqs, num_heads, max_num_partitions, head_size),
  141. dtype=output.dtype,
  142. device=output.device,
  143. )
  144. exp_sums = torch.empty(
  145. size=(num_seqs, num_heads, max_num_partitions),
  146. dtype=torch.float32,
  147. device=output.device,
  148. )
  149. max_logits = torch.empty_like(exp_sums)
  150. ops.paged_attention_v2(
  151. output,
  152. exp_sums,
  153. max_logits,
  154. tmp_output,
  155. query,
  156. key_cache,
  157. value_cache,
  158. num_kv_heads,
  159. scale,
  160. block_tables,
  161. seq_lens,
  162. block_size,
  163. max_seq_len,
  164. alibi_slopes,
  165. kv_cache_dtype,
  166. k_scale,
  167. v_scale,
  168. tp_rank,
  169. blocksparse_local_blocks,
  170. blocksparse_vert_stride,
  171. blocksparse_block_size,
  172. blocksparse_head_sliding_step,
  173. )
  174. return output
  175. @staticmethod
  176. def forward_prefix(
  177. query: torch.Tensor,
  178. key: torch.Tensor,
  179. value: torch.Tensor,
  180. kv_cache_dtype: str,
  181. key_cache: torch.Tensor,
  182. value_cache: torch.Tensor,
  183. block_tables: torch.Tensor,
  184. query_start_loc: torch.Tensor,
  185. seq_lens_tensor: torch.Tensor,
  186. context_lens: torch.Tensor,
  187. max_query_len: int,
  188. alibi_slopes: Optional[torch.Tensor],
  189. sliding_window: Optional[int],
  190. k_scale: float,
  191. v_scale: float,
  192. ) -> torch.Tensor:
  193. output = torch.empty_like(query)
  194. context_attention_fwd(
  195. query,
  196. key,
  197. value,
  198. output,
  199. kv_cache_dtype,
  200. key_cache,
  201. value_cache,
  202. block_tables,
  203. # query_start_loc is (batch_size + 1,)
  204. query_start_loc[:-1],
  205. seq_lens_tensor,
  206. context_lens,
  207. max_query_len,
  208. k_scale,
  209. v_scale,
  210. alibi_slopes,
  211. sliding_window,
  212. )
  213. return output
  214. @staticmethod
  215. def swap_blocks(
  216. src_kv_cache: torch.Tensor,
  217. dst_kv_cache: torch.Tensor,
  218. src_to_dst: torch.Tensor,
  219. ) -> None:
  220. src_key_cache = src_kv_cache[0]
  221. dst_key_cache = dst_kv_cache[0]
  222. ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
  223. src_value_cache = src_kv_cache[1]
  224. dst_value_cache = dst_kv_cache[1]
  225. ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
  226. @staticmethod
  227. def copy_blocks(
  228. kv_caches: List[torch.Tensor],
  229. src_to_dists: torch.Tensor,
  230. ) -> None:
  231. key_caches = [kv_cache[0] for kv_cache in kv_caches]
  232. value_caches = [kv_cache[1] for kv_cache in kv_caches]
  233. ops.copy_blocks(key_caches, value_caches, src_to_dists)