1
0

blocksparse_attn.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. from dataclasses import dataclass, field
  2. from typing import Any, Dict, List, Optional, Tuple, Type
  3. import torch
  4. from aphrodite.attention.backends.abstract import (AttentionBackend,
  5. AttentionImpl,
  6. AttentionMetadata,
  7. AttentionType)
  8. from aphrodite.attention.backends.utils import (CommonAttentionState,
  9. CommonMetadataBuilder)
  10. from aphrodite.attention.ops.blocksparse_attention.interface import (
  11. LocalStridedBlockSparseAttn, get_head_sliding_step)
  12. from aphrodite.attention.ops.paged_attn import PagedAttention
  13. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  14. get_tensor_model_parallel_world_size)
  15. @dataclass
  16. class BlocksparseParams:
  17. max_seqlen: int
  18. # Num q heads per tensor-parallel rank/partition
  19. num_heads: int # per TP partition
  20. # Num kv heads per tensor-parallel rank/partition
  21. num_kv_heads: int
  22. # block size used for blocksparse attention.
  23. # This is the block_size used in `local_blocks`, `vert_stride`.
  24. block_size: int
  25. # Number of blocks for local attention, i.e., number of
  26. # local attended tokens / `sparse_block_size`
  27. local_blocks: int
  28. # Attend to one block per every `vert_stride` blocks.
  29. # Controlling the sparsity
  30. vert_stride: int
  31. """
  32. If to use the same vertical stride offset for all heads,
  33. i.e., attend to the same block of tokens on all heads.
  34. By default, it is False, i.e., attention on the non-local
  35. blocks depends on the `head_idx`, that is on
  36. blocks satisfying
  37. `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0`
  38. where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`,
  39. `block_idx = position_id // sparse_block_size`.
  40. See `..ops.blocksparse_attention.utils:get_sparse_attn_mask`
  41. for more detail.
  42. """
  43. homo_head: bool = False
  44. # If within a group, the kv offsets that each q attends is the same or no.
  45. homo_head_group: bool = False
  46. # Decided by homo_head and homo_head group
  47. head_sliding_step: int = field(init=False)
  48. # range of q heads to for a TP rank
  49. active_head_range: Tuple = field(init=False)
  50. def __post_init__(self):
  51. assert self.block_size > 0
  52. assert self.local_blocks >= 0
  53. assert self.vert_stride >= 1
  54. assert self.num_heads % self.num_kv_heads == 0
  55. tp_size = get_tensor_model_parallel_world_size()
  56. tp_rank = get_tensor_model_parallel_rank()
  57. total_heads = tp_size * self.num_heads
  58. total_kv_heads = tp_size * self.num_kv_heads
  59. if self.homo_head:
  60. self.head_sliding_step = 0
  61. elif self.homo_head_group:
  62. head_sliding_step = get_head_sliding_step(total_kv_heads,
  63. self.vert_stride)
  64. # negative indicates sliding along kv heads, i.e., homo q group
  65. self.head_sliding_step = -head_sliding_step
  66. else:
  67. self.head_sliding_step = get_head_sliding_step(
  68. total_heads, self.vert_stride)
  69. self.active_head_range = (
  70. tp_rank * self.num_heads,
  71. (tp_rank + 1) * self.num_heads,
  72. )
  73. class BlocksparseFlashAttentionBackend(AttentionBackend):
  74. @staticmethod
  75. def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
  76. return BlocksparseFlashAttentionImpl
  77. @staticmethod
  78. def get_metadata_cls() -> Type["AttentionMetadata"]:
  79. return BlocksparseFlashAttentionMetadata
  80. @staticmethod
  81. def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
  82. return BlocksparseFlashAttentionMetadataBuilder
  83. @staticmethod
  84. def get_state_cls() -> Type["CommonAttentionState"]:
  85. return CommonAttentionState
  86. @staticmethod
  87. def get_kv_cache_shape(
  88. num_blocks: int,
  89. block_size: int,
  90. num_kv_heads: int,
  91. head_size: int,
  92. ) -> Tuple[int, ...]:
  93. return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
  94. num_kv_heads, head_size)
  95. @staticmethod
  96. def swap_blocks(
  97. src_kv_cache: torch.Tensor,
  98. dst_kv_cache: torch.Tensor,
  99. src_to_dst: Dict[int, int],
  100. ) -> None:
  101. PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
  102. @staticmethod
  103. def copy_blocks(
  104. kv_caches: List[torch.Tensor],
  105. src_to_dists: Dict[int, List[int]],
  106. ) -> None:
  107. PagedAttention.copy_blocks(kv_caches, src_to_dists)
  108. @dataclass
  109. class BlocksparseFlashAttentionMetadata(AttentionMetadata):
  110. """A copy of Metadata for FlashAttentionBackend,
  111. to avoid having to install flash_attn.
  112. NOTE: Any python object stored here is not updated when it is
  113. cuda-graph replayed. If you have values that need to be changed
  114. dynamically, it should be stored in tensor. The tensor has to be
  115. updated from `CUDAGraphRunner.forward` API.
  116. """
  117. # (batch_size,). The sequence length per sequence. Sequence length means
  118. # the computed tokens + new tokens None if it is a decoding.
  119. seq_lens: Optional[List[int]]
  120. # seq_lens stored as a tensor.
  121. seq_lens_tensor: Optional[torch.Tensor]
  122. # NOTE(sang): Definition of context_len, query_len, and seq_len.
  123. # |---------- N-1 iteration --------|
  124. # |---------------- N iteration ---------------------|
  125. # |- tokenA -|......................|-- newTokens ---|
  126. # |---------- context_len ----------|
  127. # |-------------------- seq_len ----------------------|
  128. # |-- query_len ---|
  129. # Maximum query length in the batch. None for decoding.
  130. max_query_len: Optional[int]
  131. # Maximum sequence length among prefill batch. 0 if there are decoding
  132. # requests only.
  133. max_prefill_seq_len: int
  134. # Maximum sequence length among decode batch. 0 if there are prefill
  135. # requests only.
  136. max_decode_seq_len: int
  137. # (batch_size + 1,). The cumulative subquery lengths of the sequences in
  138. # the batch, used to index into subquery. E.g., if the subquery length
  139. # is [4, 6], it is [0, 4, 10].
  140. query_start_loc: Optional[torch.Tensor]
  141. # (batch_size + 1,). The cumulative sequence lengths of the sequences in
  142. # the batch, used to index into sequence. E.g., if the sequence length is
  143. # [4, 6], it is [0, 4, 10].
  144. seq_start_loc: Optional[torch.Tensor]
  145. # (batch_size,) A tensor of context lengths (tokens that are computed
  146. # so far).
  147. context_lens_tensor: Optional[torch.Tensor]
  148. # (batch_size, max_blocks_per_seq).
  149. # Block addresses per sequence. (Seq id -> list of physical block)
  150. # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
  151. # in the kv cache. Each block can contain up to block_size tokens.
  152. # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
  153. # captured.
  154. block_tables: Optional[torch.Tensor]
  155. # Whether or not if cuda graph is enabled.
  156. # Cuda-graph is currently enabled for decoding only.
  157. # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
  158. use_cuda_graph: bool
  159. _cached_prefill_metadata: Optional[
  160. "BlocksparseFlashAttentionMetadata"] = None
  161. _cached_decode_metadata: Optional[
  162. "BlocksparseFlashAttentionMetadata"] = None
  163. @property
  164. def prefill_metadata(
  165. self) -> Optional["BlocksparseFlashAttentionMetadata"]:
  166. if self.num_prefills == 0:
  167. return None
  168. if self._cached_prefill_metadata is not None:
  169. return self._cached_prefill_metadata
  170. assert self.seq_lens is not None
  171. assert self.seq_lens_tensor is not None
  172. assert self.query_start_loc is not None
  173. assert self.context_lens_tensor is not None
  174. assert self.block_tables is not None
  175. assert self.seq_start_loc is not None
  176. self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata(
  177. num_prefills=self.num_prefills,
  178. num_prefill_tokens=self.num_prefill_tokens,
  179. num_decode_tokens=0,
  180. slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
  181. seq_lens=self.seq_lens[:self.num_prefills],
  182. seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
  183. max_query_len=self.max_query_len,
  184. max_prefill_seq_len=self.max_prefill_seq_len,
  185. max_decode_seq_len=0,
  186. query_start_loc=self.query_start_loc[:self.num_prefills + 1],
  187. seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
  188. context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
  189. block_tables=self.block_tables[:self.num_prefills],
  190. use_cuda_graph=False,
  191. )
  192. return self._cached_prefill_metadata
  193. @property
  194. def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
  195. if self.num_decode_tokens == 0:
  196. return None
  197. if self._cached_decode_metadata is not None:
  198. return self._cached_decode_metadata
  199. assert self.block_tables is not None
  200. assert self.seq_lens_tensor is not None
  201. self._cached_decode_metadata = BlocksparseFlashAttentionMetadata(
  202. num_prefills=0,
  203. num_prefill_tokens=0,
  204. num_decode_tokens=self.num_decode_tokens,
  205. slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
  206. seq_lens=None,
  207. seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
  208. max_query_len=None,
  209. max_prefill_seq_len=0,
  210. max_decode_seq_len=self.max_decode_seq_len,
  211. query_start_loc=None,
  212. seq_start_loc=None,
  213. context_lens_tensor=None,
  214. block_tables=self.block_tables[self.num_prefills:],
  215. use_cuda_graph=self.use_cuda_graph,
  216. )
  217. return self._cached_decode_metadata
  218. class BlocksparseFlashAttentionMetadataBuilder(
  219. CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]):
  220. _metadata_cls = BlocksparseFlashAttentionMetadata
  221. class BlocksparseFlashAttentionImpl(AttentionImpl):
  222. """
  223. If the input tensors contain prompt tokens, the layout is as follows:
  224. |<--------------- num_prompt_tokens -------------->|
  225. |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
  226. Otherwise, the layout is as follows:
  227. |<------------------ num_generation_tokens (M) ----------------->|
  228. |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
  229. Generation tokens can contain padding when cuda-graph is used.
  230. Currently, prompt tokens don't contain any padding.
  231. The prompts might have different lengths, while the generation tokens
  232. always have length 1.
  233. """
  234. def __init__(
  235. self,
  236. num_heads: int,
  237. head_size: int,
  238. scale: float,
  239. num_kv_heads: int,
  240. alibi_slopes: Optional[List[float]],
  241. sliding_window: Optional[int],
  242. kv_cache_dtype: str,
  243. blocksparse_params: Optional[Dict[str, Any]] = None,
  244. logits_soft_cap: Optional[float] = None,
  245. ) -> None:
  246. assert blocksparse_params is not None
  247. assert alibi_slopes is None, ValueError(
  248. "Alibi not support for blocksparse flash attention.")
  249. assert sliding_window is None, ValueError(
  250. "sliding_window is invalid for blocksparse attention.")
  251. assert logits_soft_cap is None, ValueError(
  252. "logits_soft_cap is invalid for blocksparse attention.")
  253. if "num_heads" not in blocksparse_params:
  254. blocksparse_params["num_heads"] = num_heads
  255. if "num_kv_heads" not in blocksparse_params:
  256. blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads
  257. self.blocksparse_params = BlocksparseParams(**blocksparse_params)
  258. self.kv_cache_dtype = kv_cache_dtype
  259. self.num_heads = num_heads
  260. self.head_size = head_size
  261. self.scale = float(scale)
  262. self.alibi_slopes = alibi_slopes
  263. self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
  264. assert self.num_heads % self.num_kv_heads == 0
  265. self.num_queries_per_kv = self.num_heads // self.num_kv_heads
  266. self.local_blocks = self.blocksparse_params.local_blocks
  267. self.vert_stride = self.blocksparse_params.vert_stride
  268. self.sparse_block_size = self.blocksparse_params.block_size
  269. self.head_sliding_step = self.blocksparse_params.head_sliding_step
  270. suppored_head_sizes = PagedAttention.get_supported_head_sizes()
  271. if head_size not in suppored_head_sizes:
  272. raise ValueError(
  273. f"Head size {head_size} is not supported by PagedAttention. "
  274. f"Supported head sizes are: {suppored_head_sizes}.")
  275. self.tp_size = get_tensor_model_parallel_world_size()
  276. self.tp_rank = get_tensor_model_parallel_rank()
  277. total_num_heads = num_heads * self.tp_size
  278. self.bs_attn = LocalStridedBlockSparseAttn(
  279. total_num_heads,
  280. self.blocksparse_params.max_seqlen,
  281. self.blocksparse_params.local_blocks,
  282. self.blocksparse_params.vert_stride,
  283. self.blocksparse_params.block_size,
  284. homo_head=self.blocksparse_params.homo_head,
  285. active_head_range=self.blocksparse_params.active_head_range,
  286. )
  287. def forward(
  288. self,
  289. query: torch.Tensor,
  290. key: torch.Tensor,
  291. value: torch.Tensor,
  292. kv_cache: torch.Tensor,
  293. attn_metadata: BlocksparseFlashAttentionMetadata,
  294. k_scale: float = 1.0,
  295. v_scale: float = 1.0,
  296. attn_type: AttentionType = AttentionType.DECODER,
  297. ) -> torch.Tensor:
  298. """Forward pass with FlashAttention and PagedAttention.
  299. Args:
  300. query: shape = [num_tokens, num_heads * head_size]
  301. key: shape = [num_tokens, num_kv_heads * head_size]
  302. value: shape = [num_tokens, num_kv_heads * head_size]
  303. kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
  304. attn_metadata: Metadata for attention.
  305. Returns:
  306. shape = [num_tokens, num_heads * head_size]
  307. """
  308. if attn_type != AttentionType.DECODER:
  309. raise NotImplementedError("Encoder self-attention and "
  310. "encoder/decoder cross-attention "
  311. "are not implemented for "
  312. "BlocksparseFlashAttentionImpl")
  313. num_tokens, hidden_size = query.shape
  314. # Reshape the query, key, and value tensors.
  315. query = query.view(-1, self.num_heads, self.head_size)
  316. key = key.view(-1, self.num_kv_heads, self.head_size)
  317. value = value.view(-1, self.num_kv_heads, self.head_size)
  318. if kv_cache is not None:
  319. key_cache, value_cache = PagedAttention.split_kv_cache(
  320. kv_cache, self.num_kv_heads, self.head_size)
  321. # Reshape the input keys and values and store them in the cache.
  322. # If kv_cache is not provided, the new key and value tensors are
  323. # not cached. This happens during the initial memory profiling run.
  324. PagedAttention.write_to_paged_cache(
  325. key,
  326. value,
  327. key_cache,
  328. value_cache,
  329. attn_metadata.slot_mapping,
  330. self.kv_cache_dtype,
  331. k_scale,
  332. v_scale,
  333. )
  334. if prefill_meta := attn_metadata.prefill_metadata:
  335. # Prompt run.
  336. # normal attention
  337. # When block_tables are not filled, it means q and k are the
  338. # prompt, and they have the same length.
  339. assert kv_cache is None \
  340. or prefill_meta.block_tables is None \
  341. or prefill_meta.block_tables.numel() == 0, \
  342. "Does not support prefix-enabled attention."
  343. output = self.bs_attn(
  344. q=query,
  345. k=key,
  346. v=value,
  347. cu_seqlens_q=prefill_meta.seq_start_loc,
  348. cu_seqlens_k=prefill_meta.seq_start_loc,
  349. sm_scale=self.scale,
  350. )
  351. if decode_meta := attn_metadata.decode_metadata:
  352. # Decoding run.
  353. output = PagedAttention.forward_decode(
  354. query,
  355. key_cache,
  356. value_cache,
  357. decode_meta.block_tables,
  358. decode_meta.seq_lens_tensor,
  359. self.blocksparse_params.max_seqlen,
  360. self.kv_cache_dtype,
  361. self.num_kv_heads,
  362. self.scale,
  363. self.alibi_slopes,
  364. k_scale,
  365. v_scale,
  366. tp_rank=self.tp_rank,
  367. blocksparse_local_blocks=self.local_blocks,
  368. blocksparse_vert_stride=self.vert_stride,
  369. blocksparse_block_size=self.sparse_block_size,
  370. blocksparse_head_sliding_step=self.head_sliding_step,
  371. )
  372. # Reshape the output tensor.
  373. return output.view(num_tokens, hidden_size)