blocksparse_attn.py

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                   AttentionImpl,
                                                   AttentionMetadata,
                                                   AttentionType)
from aphrodite.attention.backends.utils import CommonMetadataBuilder
from aphrodite.attention.ops.blocksparse_attention.interface import (
    LocalStridedBlockSparseAttn, get_head_sliding_step)
from aphrodite.attention.ops.paged_attn import PagedAttention
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size)


@dataclass
class BlocksparseParams:
    max_seqlen: int

    # Num q heads per tensor-parallel rank/partition
    num_heads: int  # per TP partition
    # Num kv heads per tensor-parallel rank/partition
    num_kv_heads: int

    # Block size used for blocksparse attention.
    # This is the block_size used in `local_blocks`, `vert_stride`.
    block_size: int

    # Number of blocks for local attention, i.e., number of
    # locally attended tokens / `sparse_block_size`
    local_blocks: int

    # Attend to one block per every `vert_stride` blocks.
    # Controls the sparsity.
    vert_stride: int

    """
    Whether to use the same vertical stride offset for all heads,
    i.e., attend to the same block of tokens on all heads.
    By default, it is False, i.e., attention on the non-local
    blocks depends on the `head_idx`, that is, on
    blocks satisfying
    `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0`
    where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`,
          `block_idx = position_id // sparse_block_size`.
    See `..ops.blocksparse_attention.utils:get_sparse_attn_mask`
    for more detail.
    """
    homo_head: bool = False

    # If within a group, the kv offsets that each q attends to are the same
    # or not.
    homo_head_group: bool = False

    # Decided by homo_head and homo_head_group.
    head_sliding_step: int = field(init=False)

    # Range of q heads for a TP rank.
    active_head_range: Tuple = field(init=False)

    def __post_init__(self):
        assert self.block_size > 0
        assert self.local_blocks >= 0
        assert self.vert_stride >= 1
        assert self.num_heads % self.num_kv_heads == 0

        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        total_heads = tp_size * self.num_heads
        total_kv_heads = tp_size * self.num_kv_heads

        if self.homo_head:
            self.head_sliding_step = 0
        elif self.homo_head_group:
            head_sliding_step = get_head_sliding_step(total_kv_heads,
                                                      self.vert_stride)
            # Negative indicates sliding along kv heads, i.e., homo q group.
            self.head_sliding_step = -head_sliding_step
        else:
            self.head_sliding_step = get_head_sliding_step(
                total_heads, self.vert_stride)

        self.active_head_range = (
            tp_rank * self.num_heads,
            (tp_rank + 1) * self.num_heads,
        )
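

# ---------------------------------------------------------------------------
# Illustrative sketch only (not used by the backend): which past kv blocks a
# query block attends to under the local + vertical-stride pattern described
# in the docstring above. The helper name is hypothetical and the authoritative
# mask construction lives in
# `..ops.blocksparse_attention.utils:get_sparse_attn_mask`; this merely
# restates the documented formula for the `homo_head` / per-head cases
# (the negative, kv-head-sliding step used for `homo_head_group` is not
# modeled here).
def _sketch_attended_blocks(q_block_idx: int,
                            head_idx: int,
                            num_total_heads: int,
                            local_blocks: int,
                            vert_stride: int,
                            homo_head: bool = False) -> List[int]:
    # homo_head -> every head shares the same offset (sliding step of 0).
    head_sliding_step = (0 if homo_head else
                         max(1, int(vert_stride / num_total_heads)))
    attended = []
    for block_idx in range(q_block_idx + 1):  # causal: past blocks only
        # Local window: the most recent `local_blocks` blocks.
        is_local = block_idx > q_block_idx - local_blocks
        # Vertical stride: one block per every `vert_stride` blocks,
        # offset per head as in the docstring formula.
        is_vert = ((block_idx + head_idx * head_sliding_step + 1) %
                   vert_stride == 0)
        if is_local or is_vert:
            attended.append(block_idx)
    return attended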


class BlocksparseFlashAttentionBackend(AttentionBackend):

    @staticmethod
    def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
        return BlocksparseFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return BlocksparseFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
        return BlocksparseFlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class BlocksparseFlashAttentionMetadata(AttentionMetadata):
    """A copy of Metadata for FlashAttentionBackend,
    to avoid having to install flash_attn.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored in a tensor. The tensor has to be
    updated from the `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int]
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery lengths
    # are [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence lengths
    # are [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,). A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical blocks)
    # E.g., [0, 1, 2] means tokens are stored in the 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # The 2nd dimension is padded up to max_blocks_per_seq if it is
    # cuda-graph captured.
    block_tables: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    _cached_prefill_metadata: Optional[
        "BlocksparseFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional[
        "BlocksparseFlashAttentionMetadata"] = None

    @property
    def prefill_metadata(
            self) -> Optional["BlocksparseFlashAttentionMetadata"]:
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.block_tables is not None
        assert self.seq_start_loc is not None

        self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
            block_tables=self.block_tables[:self.num_prefills],
            use_cuda_graph=False,
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        self._cached_decode_metadata = BlocksparseFlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[self.num_prefills:],
            use_cuda_graph=self.use_cuda_graph,
        )
        return self._cached_decode_metadata
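

# ---------------------------------------------------------------------------
# Illustrative sketch only (not used by the backend): how the cumulative
# `query_start_loc` / `seq_start_loc` tensors documented above relate to the
# per-sequence lengths. With lengths [4, 6] the cumulative tensor is
# [0, 4, 10], so sequence i owns the half-open token range
# [start_loc[i], start_loc[i + 1]). The helper name is hypothetical.
def _sketch_cumulative_start_loc(lens: List[int]) -> torch.Tensor:
    lens_tensor = torch.tensor(lens, dtype=torch.int32)
    start_loc = torch.zeros(len(lens) + 1, dtype=torch.int32)
    start_loc[1:] = torch.cumsum(lens_tensor, dim=0)
    return start_loc  # e.g. [4, 6] -> tensor([0, 4, 10], dtype=torch.int32)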


class BlocksparseFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]):

    _metadata_cls = BlocksparseFlashAttentionMetadata


class BlocksparseFlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        assert blocksparse_params is not None
        assert alibi_slopes is None, ValueError(
            "Alibi is not supported for blocksparse flash attention.")
        assert sliding_window is None, ValueError(
            "sliding_window is invalid for blocksparse attention.")
        assert logits_soft_cap is None, ValueError(
            "logits_soft_cap is invalid for blocksparse attention.")

        if "num_heads" not in blocksparse_params:
            blocksparse_params["num_heads"] = num_heads
        if "num_kv_heads" not in blocksparse_params:
            blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads
        self.blocksparse_params = BlocksparseParams(**blocksparse_params)
        self.kv_cache_dtype = kv_cache_dtype

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.alibi_slopes = alibi_slopes
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.local_blocks = self.blocksparse_params.local_blocks
        self.vert_stride = self.blocksparse_params.vert_stride
        self.sparse_block_size = self.blocksparse_params.block_size
        self.head_sliding_step = self.blocksparse_params.head_sliding_step

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        total_num_heads = num_heads * self.tp_size
        self.bs_attn = LocalStridedBlockSparseAttn(
            total_num_heads,
            self.blocksparse_params.max_seqlen,
            self.blocksparse_params.local_blocks,
            self.blocksparse_params.vert_stride,
            self.blocksparse_params.block_size,
            homo_head=self.blocksparse_params.homo_head,
            active_head_range=self.blocksparse_params.active_head_range,
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: BlocksparseFlashAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "BlocksparseFlashAttentionImpl")

        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                k_scale,
                v_scale,
            )

        if prefill_meta := attn_metadata.prefill_metadata:

            # Prompt run.
            # Normal attention.
            # When block_tables are not filled, it means q and k are the
            # prompt, and they have the same length.
            assert kv_cache is None \
                   or prefill_meta.block_tables is None \
                   or prefill_meta.block_tables.numel() == 0, \
                "Does not support prefix-enabled attention."

            output = self.bs_attn(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=prefill_meta.seq_start_loc,
                cu_seqlens_k=prefill_meta.seq_start_loc,
                sm_scale=self.scale,
            )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            output = PagedAttention.forward_decode(
                query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.seq_lens_tensor,
                self.blocksparse_params.max_seqlen,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                k_scale,
                v_scale,
                tp_rank=self.tp_rank,
                blocksparse_local_blocks=self.local_blocks,
                blocksparse_vert_stride=self.vert_stride,
                blocksparse_block_size=self.sparse_block_size,
                blocksparse_head_sliding_step=self.head_sliding_step,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
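

# ---------------------------------------------------------------------------
# Illustrative sketch only (not used by the backend): the tensor shapes that
# `BlocksparseFlashAttentionImpl.forward` documents. Queries/keys/values
# arrive flattened over all tokens in the batch, are viewed per-head before
# the attention kernels run, and the output is flattened back the same way.
# The sizes below are arbitrary example values; the helper name is
# hypothetical. Per the forward docstring, the fused kv cache has shape
# [2, num_blocks, block_size * num_kv_heads * head_size].
def _sketch_forward_shapes() -> None:
    num_tokens, num_heads, num_kv_heads, head_size = 16, 8, 2, 128
    query = torch.randn(num_tokens, num_heads * head_size)
    key = torch.randn(num_tokens, num_kv_heads * head_size)

    # forward() views the flat tensors as [num_tokens, heads, head_size] ...
    q = query.view(-1, num_heads, head_size)
    k = key.view(-1, num_kv_heads, head_size)
    assert q.shape == (num_tokens, num_heads, head_size)
    assert k.shape == (num_tokens, num_kv_heads, head_size)

    # ... and flattens the attention output back to
    # [num_tokens, num_heads * head_size] before returning it.
    out = q.reshape(num_tokens, num_heads * head_size)
    assert out.shape == query.shape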