# flashinfer.py
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type

try:
    from aphrodite_flash_attn import flash_attn_varlen_func
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
except ImportError:
    flash_attn_varlen_func = None
    BatchDecodeWithPagedKVCacheWrapper = None
    BatchPrefillWithPagedKVCacheWrapper = None

import torch

from aphrodite import _custom_ops as ops
from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                   AttentionImpl,
                                                   AttentionMetadata,
                                                   AttentionMetadataBuilder,
                                                   AttentionType)
from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
                                                compute_slot_mapping,
                                                compute_slot_mapping_start_idx,
                                                is_block_tables_empty)
from aphrodite.attention.ops.paged_attn import PagedAttention
from aphrodite.common.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                                    make_tensor_with_pad)

if TYPE_CHECKING:
    from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder


class FlashInferBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "flashinfer"

    @staticmethod
    def get_impl_cls() -> Type["FlashInferImpl"]:
        return FlashInferImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return FlashInferMetadata

    @staticmethod
    def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
        return FlashInferMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (num_blocks, 2, block_size, num_kv_heads, head_size)
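
    # The cache above packs keys and values together: along dim 1, index 0
    # holds the key blocks and index 1 the value blocks (see kv_cache[:, 0]
    # and kv_cache[:, 1] in FlashInferImpl.forward below). For example, with
    # num_blocks=1024, block_size=16, num_kv_heads=8 and head_size=128 the
    # cache is a (1024, 2, 16, 8, 128) tensor.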

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [64, 128, 256]


@dataclass
class FlashInferMetadata(AttentionMetadata):
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int

    use_cuda_graph: bool = True

    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None

    # Metadata for the prefill stage
    seq_start_loc: Optional[torch.Tensor] = None
    query_start_loc: Optional[torch.Tensor] = None
    block_tables: Optional[torch.Tensor] = None

    # An example for paged_kv_indices, paged_kv_indptr:
    # request 1, page indices [0, 5, 8]
    # request 2, page indices [1, 6, 7]
    # request 3, page indices [3, 4]
    # paged_kv_indices is a concatenation of page indices of all requests:
    # [0, 5, 8, 1, 6, 7, 3, 4]
    # paged_kv_indptr is used to index into paged_kv_indices:
    # [0, 3, 6, 8]
    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_len: Optional[torch.Tensor] = None
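    # For example, with page_size = 16 a request with seq_len = 33 occupies
    # three pages of lengths [16, 16, 1], so its paged_kv_last_page_len entry
    # is 1; with seq_len = 32 the entry would be 16.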
    # The number of query/output heads
    num_qo_heads: Optional[int] = None
    # The number of key/value heads
    num_kv_heads: Optional[int] = None
    # The dimension of the attention heads
    head_dim: Optional[int] = None
    # Block size of Aphrodite
    page_size: Optional[int] = None
    # The data type of the paged kv cache
    data_type: Optional[torch.dtype] = None
    device: torch.device = torch.device("cuda")

    def __post_init__(self):
        # Refer to
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
        if self.head_dim is not None and self.head_dim \
                not in supported_head_sizes:
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")

    def begin_forward(self):
        if self.num_prefill_tokens > 0:
            if self.paged_kv_indices is None:
                return

            assert self.prefill_wrapper is not None
            assert self.query_start_loc is not None
            assert self.paged_kv_indices is not None
            assert self.paged_kv_indptr is not None
            assert self.paged_kv_last_page_len is not None
            batch_size = self.query_start_loc.shape[0] - 1
            assert batch_size >= 0
            # The prefill stage does not read kv cache.
            # Both paged_kv_indices and paged_kv_last_page_len are empty.
            # paged_kv_indptr is a zero tensor with size batch_size + 1.
            self.paged_kv_indptr = torch.zeros(batch_size + 1,
                                               device=self.device)
            self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                self.device)
            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
            self.prefill_wrapper.end_forward()
            self.prefill_wrapper.begin_forward(
                self.query_start_loc, self.paged_kv_indptr,
                self.paged_kv_indices, self.paged_kv_last_page_len,
                self.num_qo_heads, self.num_kv_heads, self.head_dim,
                self.page_size)
        else:
            if not self.use_cuda_graph:
                assert self.paged_kv_indices is not None
                assert self.paged_kv_indptr is not None
                assert self.paged_kv_last_page_len is not None
                self.paged_kv_indices = self.paged_kv_indices.to(self.device)
                self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
                self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                    self.device)

            assert self.decode_wrapper is not None
            self.decode_wrapper.end_forward()
            self.decode_wrapper.begin_forward(
                self.paged_kv_indptr,
                self.paged_kv_indices,
                self.paged_kv_last_page_len,
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                # Disable flashinfer's pos encoding and use Aphrodite's rope.
                pos_encoding_mode="NONE",
                data_type=self.data_type)

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        if skip_fields is None:
            skip_fields = set()
        # We need to skip the prefill/decode_wrapper field since it cannot be
        # broadcasted with nccl when TP is enabled.
        skip_fields.add('prefill_wrapper')
        skip_fields.add('decode_wrapper')
        return super().asdict_zerocopy(skip_fields)
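
    # Note: the prefill/decode wrappers are created outside this file and
    # attached to this metadata before begin_forward() runs (in Aphrodite
    # this is handled by the GPU model runner). Since they cannot be
    # broadcast, each TP rank attaches its own local wrappers.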

    @property
    def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
        # Currently chunked prefill is not supported
        if self.num_decode_tokens == 0:
            assert self.num_prefills > 0
            return self
        return None

    @property
    def decode_metadata(self) -> Optional["FlashInferMetadata"]:
        # Currently chunked prefill is not supported
        if self.num_prefills > 0:
            assert self.num_decode_tokens == 0
            return None
        return self


class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size
        self.use_v2_block_manager = (
            input_builder.scheduler_config.use_v2_block_manager)

        # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
        # for the precise definition of the following fields.
        # An example:
        # request 1, page indices [0, 5, 8]
        # request 2, page indices [1, 6, 7]
        # request 3, page indices [3, 4]
        # paged_kv_indices is a concatenation of page indices of all requests:
        # [0, 5, 8, 1, 6, 7, 3, 4]
        # paged_kv_indptr is used to index into paged_kv_indices:
        # [0, 3, 6, 8]
        self.paged_kv_indices: List[int] = []
        # 0 at the beginning of paged_kv_indptr indicates the start of the
        # first request's page indices in the paged_kv_indices list.
        self.paged_kv_indptr: List[int] = [0]
        # paged_kv_last_page_len is the length of the last page of each
        # request.
        self.paged_kv_last_page_len: List[int] = []

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables
        computed_block_nums = inter_data.computed_block_nums

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO: Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if inter_data.prefix_cache_hit:
                block_table = computed_block_nums
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                block_table = block_tables[seq_id][-curr_sliding_window_block:]
            self.block_tables.append(block_table)

            is_profile_run = is_block_tables_empty(block_tables)

            # Compute slot mapping.
            start_idx = compute_slot_mapping_start_idx(
                is_prompt, query_len, context_len, self.sliding_window,
                self.use_v2_block_manager)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

            # It is not necessary to add paged_kv_indices, paged_kv_indptr,
            # and paged_kv_last_page_len for profile run because we will
            # create dummy inputs.
            if is_profile_run:
                return

            block_table = block_tables[seq_id]
            self._update_paged_kv_tensors(block_table, seq_len)

    def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
        # Get the number of valid blocks based on sequence length.
        # If seq_len = 16, block_size = 16,
        # block_table_bound is 1 with 1 valid block.
        # If seq_len = 15, block_size = 16,
        # block_table_bound is 0 + 1 with 1 valid block.
        block_table_bound = seq_len // self.block_size + 1 \
            if seq_len % self.block_size != 0 \
            else seq_len // self.block_size
        self.paged_kv_indices.extend(block_table[:block_table_bound])
        self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                    block_table_bound)

        last_page_len = seq_len % self.block_size
        if last_page_len == 0:
            last_page_len = self.block_size
        self.paged_kv_last_page_len.append(last_page_len)
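
    # A worked example for the helper above, assuming block_size = 16:
    #   seq_len = 40, block_table = [0, 5, 8, 9]:
    #     block_table_bound = 40 // 16 + 1 = 3, so paged_kv_indices gains
    #     [0, 5, 8], paged_kv_indptr grows by 3, and last_page_len = 8.
    #   seq_len = 32, block_table = [1, 6]:
    #     block_table_bound = 32 // 16 = 2 and last_page_len = 16.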

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # Pad the block tables and paged-KV metadata so the padded decode
            # slots attend to nothing: an empty block table, a repeated indptr
            # value (zero pages), and a last page length of 0.
            self.block_tables.extend([[]] * cuda_graph_pad_size)
            num_decode_tokens = batch_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.runner.graph_block_tables[:batch_size]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.from_numpy(input_block_tables).to(
                device, non_blocking=True)

            last_paged_kv_indptr = self.paged_kv_indptr[-1]
            self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                        cuda_graph_pad_size)
            self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        assert device is not None
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
                                             self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])
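        # For example, query_lens = [5, 1, 1] and seq_lens = [5, 9, 12] yield
        # query_start_loc = [0, 5, 6, 7] and seq_start_loc = [0, 5, 14, 26].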

        if len(self.paged_kv_indptr) > 0:
            paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
                                                   device="cpu",
                                                   dtype=torch.int)
            paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
                                                  device="cpu",
                                                  dtype=torch.int)
            paged_kv_last_page_len_tensor = torch.tensor(
                self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
        else:
            paged_kv_indices_tensor = None
            paged_kv_indptr_tensor = None
            paged_kv_last_page_len_tensor = None

        kv_cache_dtype = get_kv_cache_torch_dtype(
            self.runner.kv_cache_dtype, self.runner.model_config.dtype)
        return FlashInferMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            max_prefill_seq_len=max_prefill_seq_len,
            block_tables=block_tables,
            paged_kv_indptr=paged_kv_indptr_tensor,
            paged_kv_indices=paged_kv_indices_tensor,
            paged_kv_last_page_len=paged_kv_last_page_len_tensor,
            num_qo_heads=self.runner.model_config.get_num_attention_heads(
                self.runner.parallel_config),
            num_kv_heads=self.runner.model_config.get_num_kv_heads(
                self.runner.parallel_config),
            head_dim=self.runner.model_config.get_head_size(),
            page_size=self.block_size,
            seq_start_loc=seq_start_loc,
            query_start_loc=query_start_loc,
            device=device,
            data_type=kv_cache_dtype,
            use_cuda_graph=use_captured_graph)


class FlashInferImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        assert blocksparse_params is None, ValueError(
            "FlashInfer does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is not None:
            raise ValueError("Sliding window is not supported in FlashInfer.")
        self.sliding_window = (-1, -1)
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
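        # With grouped-query attention, e.g. 32 query heads over 8 KV heads,
        # num_queries_per_kv = 4, i.e. each KV head serves 4 query heads.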

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: FlashInferMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        assert k_scale == 1.0 and v_scale == 1.0, (
            "key/v_scale is not supported in FlashInfer.")
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashInferImpl")
        num_tokens, hidden_size = query.shape
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if attn_metadata.num_prefill_tokens > 0:
            assert attn_metadata.num_decode_tokens == 0, (
                "Chunked prefill is not supported with flashinfer yet.")
        if attn_metadata.num_decode_tokens > 0:
            assert attn_metadata.num_prefill_tokens == 0, (
                "Chunked prefill is not supported with flashinfer yet.")

        if kv_cache is not None:
            # Use the same reshape and cache kernel as flash attention.
            ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping.flatten(),
                self.kv_cache_dtype,
                k_scale,
                v_scale,
            )

        # FlashInfer requires the query to be contiguous.
        query = query.contiguous()
        if prefill_meta := attn_metadata.prefill_metadata:
            # We will use flash attention for prefill
            # when kv_cache is not provided.
            # This happens when Aphrodite runs the profiling to
            # determine the number of blocks.
            if kv_cache is None:
                output = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
            else:
                assert prefill_meta is not None
                assert prefill_meta.prefill_wrapper is not None
                output = prefill_meta.prefill_wrapper.forward(
                    query,
                    kv_cache,
                    logits_soft_cap=self.logits_soft_cap,
                    causal=True)
        else:
            assert attn_metadata.decode_metadata is not None
            assert attn_metadata.decode_metadata.decode_wrapper is not None
            output = attn_metadata.decode_metadata.decode_wrapper.forward(
                query,
                kv_cache,
                sm_scale=self.scale,
                logits_soft_cap=self.logits_soft_cap)
        return output.view(num_tokens, hidden_size)
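
# ---------------------------------------------------------------------------
# A minimal usage sketch (not part of this backend; the constructor arguments
# and the 128 MiB workspace size are assumptions based on FlashInfer's
# documented wrapper API): the caller allocates a workspace buffer, wraps it,
# attaches the wrapper to the metadata built by FlashInferMetadataBuilder, and
# calls begin_forward() before the attention layers call
# FlashInferImpl.forward().
#
#     workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8,
#                             device="cuda")
#     attn_metadata.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
#         workspace, "NHD")
#     attn_metadata.begin_forward()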