# flashinfer.py

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type

try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper

    from vllm_flash_attn import flash_attn_varlen_func
except ImportError:
    flash_attn_varlen_func = None
    BatchDecodeWithPagedKVCacheWrapper = None
    BatchPrefillWithPagedKVCacheWrapper = None

import torch

from aphrodite import _custom_ops as ops
from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                   AttentionImpl,
                                                   AttentionMetadata,
                                                   AttentionMetadataBuilder,
                                                   AttentionType)
from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
                                                compute_slot_mapping,
                                                compute_slot_mapping_start_idx,
                                                is_block_tables_empty)
from aphrodite.attention.ops.paged_attn import PagedAttention
from aphrodite.common.sequence import SequenceGroupMetadata
from aphrodite.common.utils import (get_kv_cache_torch_dtype,
                                    make_tensor_with_pad)

if TYPE_CHECKING:
    from aphrodite.task_handler.model_runner import (GPUModelRunnerBase,
                                                     ModelInputForGPUBuilder)


class FlashInferBackend(AttentionBackend):
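    """Attention backend that dispatches to FlashInfer's paged-KV kernels.

    Prefill and decode batches are served by FlashInfer's
    BatchPrefillWithPagedKVCacheWrapper and
    BatchDecodeWithPagedKVCacheWrapper respectively; block swap/copy is
    delegated to the PagedAttention ops.
    """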

    @staticmethod
    def get_name() -> str:
        return "flashinfer"

    @staticmethod
    def get_impl_cls() -> Type["FlashInferImpl"]:
        return FlashInferImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return FlashInferMetadata

    @staticmethod
    def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
        return FlashInferMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [64, 128, 256]


@dataclass
class FlashInferMetadata(AttentionMetadata):
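    """Attention metadata consumed by FlashInferImpl for a single batch.

    A batch is either all-prefill or all-decode (chunked prefill is not
    supported), and begin_forward() must be called to plan the matching
    FlashInfer wrapper before the attention forward pass.
    """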

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    use_cuda_graph: bool = True

    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None

    # Metadata for the prefill stage
    seq_start_loc: Optional[torch.Tensor] = None
    query_start_loc: Optional[torch.Tensor] = None
    block_tables: Optional[torch.Tensor] = None

    # An example for paged_kv_indices, paged_kv_indptr:
    # request 1, page indices [0, 5, 8]
    # request 2, page indices [1, 6, 7]
    # request 3, page indices [3, 4]
    # paged_kv_indices is a concatenation of page indices of all requests:
    # [0, 5, 8, 1, 6, 7, 3, 4]
    # paged_kv_indptr is used to index into paged_kv_indices:
    # [0, 3, 6, 8]

    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_len: Optional[torch.Tensor] = None
    # The number of query/output heads
    num_qo_heads: Optional[int] = None
    # The number of key/value heads
    num_kv_heads: Optional[int] = None
    # The dimension of the attention heads
    head_dim: Optional[int] = None
    # Block size of Aphrodite
    page_size: Optional[int] = None
    # The data type of the paged kv cache
    data_type: torch.dtype = None
    device: torch.device = torch.device("cuda")
    logits_soft_cap: Optional[float] = None

    def __post_init__(self):
        # Refer to
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
        if self.head_dim is not None and self.head_dim \
                not in supported_head_sizes:
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")

    def begin_forward(self):
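        """Plan the FlashInfer wrapper for the upcoming forward pass.

        Moves the paged-KV indexing tensors to the target device (when
        needed) and re-plans either the prefill or the decode wrapper,
        depending on whether this batch contains prefill or decode tokens.
        """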
        if self.num_prefill_tokens > 0:
            if self.paged_kv_indices is None:
                return

            assert self.prefill_wrapper is not None
            assert self.paged_kv_indices is not None
            assert self.paged_kv_indptr is not None
            assert self.paged_kv_last_page_len is not None
            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
            self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
            self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                self.device)
            self.prefill_wrapper.end_forward()
            self.prefill_wrapper.begin_forward(
                self.query_start_loc, self.paged_kv_indptr,
                self.paged_kv_indices, self.paged_kv_last_page_len,
                self.num_qo_heads, self.num_kv_heads, self.head_dim,
                self.page_size)
        else:
            if not self.use_cuda_graph:
                assert self.paged_kv_indices is not None
                assert self.paged_kv_indptr is not None
                assert self.paged_kv_last_page_len is not None
                self.paged_kv_indices = self.paged_kv_indices.to(self.device)
                self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
                self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                    self.device)

            assert self.decode_wrapper is not None
            self.decode_wrapper.end_forward()
            self.decode_wrapper.begin_forward(
                self.paged_kv_indptr,
                self.paged_kv_indices,
                self.paged_kv_last_page_len,
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                # Disable flashinfer's pos encoding and use Aphrodite's rope.
                pos_encoding_mode="NONE",
                data_type=self.data_type)

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        if skip_fields is None:
            skip_fields = set()
        # We need to skip the prefill/decode_wrapper field since it cannot be
        # broadcasted with nccl when TP is enabled.
        skip_fields.add('prefill_wrapper')
        skip_fields.add('decode_wrapper')
        return super().asdict_zerocopy(skip_fields)

    @property
    def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
        # Currently chunked prefill is not supported
        if self.num_decode_tokens == 0:
            assert self.num_prefills > 0
            return self
        return None

    @property
    def decode_metadata(self) -> Optional["FlashInferMetadata"]:
        # Currently chunked prefill is not supported
        if self.num_prefills > 0:
            assert self.num_decode_tokens == 0
            return None
        return self


class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
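    """Accumulates per-sequence-group state and builds FlashInferMetadata.

    add_seq_group() is called once per sequence group to collect slot
    mappings, block tables, and the paged-KV indexing lists; build() then
    materializes them as tensors for the whole batch.
    """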

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size
        self.use_v2_block_manager = (
            input_builder.scheduler_config.use_v2_block_manager)

        # Please follow
        # https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
        # for the precise definition of the following fields.
        # An example:
        # request 1, page indices [0, 5, 8]
        # request 2, page indices [1, 6, 7]
        # request 3, page indices [3, 4]
        # paged_kv_indices is a concatenation of page indices of all requests:
        # [0, 5, 8, 1, 6, 7, 3, 4]
        # paged_kv_indptr is used to index into paged_kv_indices:
        # [0, 3, 6, 8]
        self.paged_kv_indices: List[int] = []
        # 0 at the beginning of paged_kv_indptr indicates the start of the
        # first request's page indices in the paged_kv_indices list.
        self.paged_kv_indptr: List[int] = [0]
        # paged_kv_last_page_len is the length of the last page of each
        # request.
        self.paged_kv_last_page_len: List[int] = []

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                      token_lens: List[int], seq_lens: List[int],
                      curr_seq_lens: List[int], query_lens: List[int],
                      context_lens: List[int],
                      curr_sliding_window_blocks: List[int],
                      prefix_cache_hit: bool, chunked_prefill_enabled: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.
        """
        is_prompt = seq_group_metadata.is_prompt
        block_tables = seq_group_metadata.block_tables
        computed_block_nums = seq_group_metadata.computed_block_nums

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
                 curr_seq_lens, query_lens, context_lens,
                 curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if prefix_cache_hit:
                block_table = computed_block_nums
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                block_table = block_tables[seq_id][-curr_sliding_window_block:]
            self.block_tables.append(block_table)

            is_profile_run = is_block_tables_empty(block_tables)

            # Compute slot mapping.
            start_idx = compute_slot_mapping_start_idx(
                is_prompt, query_len, context_len, self.sliding_window,
                self.use_v2_block_manager)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size,
                                 seq_group_metadata.block_tables)

            # It is not necessary to add paged_kv_indices, paged_kv_indptr,
            # and paged_kv_last_page_len for profile run because we will
            # create dummy inputs.
            if is_profile_run:
                return

            # Get the number of valid blocks based on sequence length.
            # If seq_len = 16, block_size = 16,
            # block_table_bound is 1 with 1 valid block.
            # If seq_len = 15, block_size = 16,
            # block_table_bound is 0 + 1 with 1 valid block.
            block_table_bound = seq_len // self.block_size + 1 \
                if seq_len % self.block_size != 0 \
                else seq_len // self.block_size

            block_table = block_tables[seq_id]
            self.paged_kv_indices.extend(block_table[:block_table_bound])
            self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                        block_table_bound)

            last_page_len = seq_len % self.block_size
            if last_page_len == 0:
                last_page_len = self.block_size
            self.paged_kv_last_page_len.append(last_page_len)

    def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
              cuda_graph_pad_size: int, batch_size: int):
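        """Build the FlashInferMetadata for the accumulated batch.

        cuda_graph_pad_size is -1 when CUDA graphs are not used; otherwise
        slot mappings, block tables, and the paged-KV lists are padded to
        match the graph's fixed batch size.
        """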
        device = runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            self.block_tables.extend([] * cuda_graph_pad_size)
            num_decode_tokens = batch_size + cuda_graph_pad_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = runner.graph_block_tables[:batch_size]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.tensor(input_block_tables, device=device)

            last_paged_kv_indptr = self.paged_kv_indptr[-1]
            self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                        cuda_graph_pad_size)
            self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
        else:
            max_block_table_len = max(
                len(block_table) for block_table in self.block_tables)
            block_tables = make_tensor_with_pad(
                self.block_tables,
                max_len=max_block_table_len,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=device)
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=device)
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        slot_mapping_tensor = torch.tensor(self.slot_mapping,
                                           dtype=torch.long,
                                           device=device)

        logits_soft_cap = getattr(runner.model_config.hf_config,
                                  "attn_logit_softcapping", None)

        if len(self.paged_kv_indptr) > 0:
            paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
                                                   device="cpu",
                                                   dtype=torch.int)
            paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
                                                  device="cpu",
                                                  dtype=torch.int)
            paged_kv_last_page_len_tensor = torch.tensor(
                self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
        else:
            paged_kv_indices_tensor = None
            paged_kv_indptr_tensor = None
            paged_kv_last_page_len_tensor = None

        kv_cache_dtype = get_kv_cache_torch_dtype(runner.kv_cache_dtype,
                                                  runner.model_config.dtype)
        return FlashInferMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            max_prefill_seq_len=max_prefill_seq_len,
            block_tables=block_tables,
            paged_kv_indptr=paged_kv_indptr_tensor,
            paged_kv_indices=paged_kv_indices_tensor,
            paged_kv_last_page_len=paged_kv_last_page_len_tensor,
            num_qo_heads=runner.model_config.get_num_attention_heads(
                runner.parallel_config),
            num_kv_heads=runner.model_config.get_num_kv_heads(
                runner.parallel_config),
            head_dim=runner.model_config.get_head_size(),
            page_size=self.block_size,
            seq_start_loc=seq_start_loc,
            query_start_loc=query_start_loc,
            device=device,
            data_type=kv_cache_dtype,
            use_cuda_graph=use_captured_graph,
            logits_soft_cap=logits_soft_cap)


class FlashInferImpl(AttentionImpl):
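    """Attention implementation backed by FlashInfer's paged-KV kernels.

    Prefill without a KV cache (the memory-profiling run) falls back to
    flash_attn_varlen_func; otherwise the prefill/decode wrappers planned in
    FlashInferMetadata.begin_forward are used.
    """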

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        assert blocksparse_params is None, ValueError(
            "FlashInfer does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is not None:
            raise ValueError("Sliding window is not supported in FlashInfer.")
        self.sliding_window = (-1, -1)
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: FlashInferMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
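        """Run attention for one all-prefill or all-decode batch.

        Writes the new key/value tokens into kv_cache (when provided) and
        returns the attention output with shape [num_tokens, hidden_size].
        """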
        assert k_scale == 1.0 and v_scale == 1.0, (
            "key/v_scale is not supported in FlashInfer.")
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashInferImpl")
        num_tokens, hidden_size = query.shape
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if attn_metadata.num_prefill_tokens > 0:
            assert attn_metadata.num_decode_tokens == 0, (
                "Chunked prefill is not supported with flashinfer yet.")
        if attn_metadata.num_decode_tokens > 0:
            assert attn_metadata.num_prefill_tokens == 0, (
                "Chunked prefill is not supported with flashinfer yet.")

        if kv_cache is not None:
            # Use the same reshape and cache kernel as flash attention.
            ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping.flatten(),
                self.kv_cache_dtype,
            )

        # Flashinfer requires query to be contiguous.
        query = query.contiguous()

        if prefill_meta := attn_metadata.prefill_metadata:
            # We will use flash attention for prefill
            # when kv_cache is not provided.
            # This happens when vllm runs the profiling to
            # determine the number of blocks.
            if kv_cache is None:
                output = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
            else:
                assert prefill_meta is not None
                assert prefill_meta.prefill_wrapper is not None
                output = prefill_meta.prefill_wrapper.forward(
                    query,
                    kv_cache,
                    logits_soft_cap=attn_metadata.logits_soft_cap,
                    causal=True)
        else:
            assert attn_metadata.decode_metadata is not None
            assert attn_metadata.decode_metadata.decode_wrapper is not None
            output = attn_metadata.decode_metadata.decode_wrapper.forward(
                query,
                kv_cache,
                sm_scale=self.scale,
                logits_soft_cap=attn_metadata.logits_soft_cap)
        return output.view(num_tokens, hidden_size)