# paged_attn.py

from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch

from aphrodite import _custom_ops as ops
from aphrodite.attention.ops.prefix_prefill import context_attention_fwd

# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512


@dataclass
class PagedAttentionMetadata:
    """Metadata for PagedAttention."""
    # (batch_size,). The length of each sequence (total tokens seen so far).
    seq_lens_tensor: Optional[torch.Tensor]
    # Maximum sequence length in the batch. 0 if it is a prefill-only batch.
    max_decode_seq_len: int
    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence (seq id -> list of physical blocks).
    # E.g., [0, 1, 2] means tokens are stored in the 0th, 1st, and 2nd blocks
    # of the kv cache. Each block can hold up to block_size tokens.
    # The 2nd dimension is padded up to max_blocks_per_seq when the batch is
    # CUDA-graph captured.
    block_tables: Optional[torch.Tensor]


class PagedAttention:

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [64, 80, 96, 112, 128, 192, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (2, num_blocks, block_size * num_kv_heads * head_size)

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = 16 // kv_cache.element_size()
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
                                   -1, x)
        value_cache = kv_cache[1]
        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
        return key_cache, value_cache
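
    # The split above packs the key cache so that its last dimension holds
    # x = 16 bytes' worth of elements (e.g. x = 8 for float16, since
    # element_size() == 2). With illustrative sizes num_blocks=1024,
    # block_size=16, num_kv_heads=8, head_size=128, each flat cache entry is
    # 16 * 8 * 128 = 16384 elements and the views resolve to:
    #   key_cache:   (1024, 8, 128 // 8, 16, 8) == (1024, 8, 16, 16, 8)
    #   value_cache: (1024, 8, 128, 16)
    # (the -1 dimension in both views resolves to block_size).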

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
    ) -> None:
        ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping.flatten(),
            kv_cache_dtype,
            k_scale,
            v_scale,
        )
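
    # slot_mapping addresses one token slot per entry. In this paged layout a
    # slot index is assumed to encode block_number * block_size + offset
    # within the block, e.g. with block_size=16, offset 3 of physical block 7
    # maps to slot 7 * 16 + 3 = 115.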

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> torch.Tensor:
        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
            # use blocksparse paged attention
            block_size = value_cache.size(-1)
            assert (blocksparse_block_size > 0 and
                    blocksparse_block_size % block_size == 0), \
                (f"{blocksparse_block_size=} needs to be a multiple of "
                 f"{block_size=} used in block_tables.")

        output = torch.empty_like(query)
        block_size = value_cache.shape[3]
        num_seqs, num_heads, head_size = query.shape
        max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
                              _PARTITION_SIZE)
        # NOTE: We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO: Tune this heuristic.
        # For context len > 8192, use V2 kernel to avoid shared memory
        # shortage.
        use_v1 = (max_seq_len <= 8192
                  and (max_num_partitions == 1 or num_seqs * num_heads > 512))
        if use_v1:
            # Run PagedAttention V1.
            ops.paged_attention_v1(
                output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
                tp_rank,
                blocksparse_local_blocks,
                blocksparse_vert_stride,
                blocksparse_block_size,
                blocksparse_head_sliding_step,
            )
        else:
            # Run PagedAttention V2.
            assert _PARTITION_SIZE % block_size == 0
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=output.dtype,
                device=output.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=output.device,
            )
            max_logits = torch.empty_like(exp_sums)
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                k_scale,
                v_scale,
                tp_rank,
                blocksparse_local_blocks,
                blocksparse_vert_stride,
                blocksparse_block_size,
                blocksparse_head_sliding_step,
            )
        return output
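
    # Worked example of the V1/V2 heuristic above (illustrative numbers):
    # with max_seq_len=4096 and _PARTITION_SIZE=512, max_num_partitions is
    # ceil(4096 / 512) = 8, so V1 is chosen only if num_seqs * num_heads >
    # 512. A decode batch of 8 sequences with 32 heads gives 256, so the V2
    # kernel runs and reduces its per-partition exp_sums / max_logits.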

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        query_start_loc: torch.Tensor,
        seq_lens_tensor: torch.Tensor,
        context_lens: torch.Tensor,
        max_query_len: int,
        alibi_slopes: Optional[torch.Tensor],
        sliding_window: Optional[int],
    ) -> torch.Tensor:
        output = torch.empty_like(query)
        context_attention_fwd(
            query,
            key,
            value,
            output,
            key_cache,
            value_cache,
            block_tables,
            # query_start_loc is (batch_size + 1,)
            query_start_loc[:-1],
            seq_lens_tensor,
            context_lens,
            max_query_len,
            alibi_slopes,
            sliding_window,
        )
        return output
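
    # query_start_loc is assumed to hold cumulative query offsets, e.g.
    # [0, 5, 12, 20] for three sequences contributing 5, 7, and 8 new tokens;
    # the [:-1] slice above passes only the per-sequence start offsets
    # [0, 5, 12] to the kernel.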

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        src_key_cache = src_kv_cache[0]
        dst_key_cache = dst_kv_cache[0]
        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
        src_value_cache = src_kv_cache[1]
        dst_value_cache = dst_kv_cache[1]
        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)
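

# Minimal usage sketch with illustrative sizes: only the cache allocation and
# split below run with plain PyTorch; the decode/prefill paths additionally
# require the compiled aphrodite custom ops and a CUDA device.
if __name__ == "__main__":
    num_blocks, block_size, num_kv_heads, head_size = 1024, 16, 8, 128
    kv_cache = torch.zeros(
        PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                          num_kv_heads, head_size),
        dtype=torch.float16,
    )
    key_cache, value_cache = PagedAttention.split_kv_cache(
        kv_cache, num_kv_heads, head_size)
    # For float16, x = 16 // 2 = 8, so this prints
    # torch.Size([1024, 8, 16, 16, 8]) torch.Size([1024, 8, 128, 16]).
    print(key_cache.shape, value_cache.shape)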