1
0

flashinfer.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732
  1. from contextlib import contextmanager
  2. from dataclasses import dataclass
  3. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
  4. try:
  5. from flashinfer import BatchDecodeWithPagedKVCacheWrapper
  6. from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
  7. from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
  8. import aphrodite.attention.backends.flash_attn # noqa
  9. FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
  10. except ImportError:
  11. BatchDecodeWithPagedKVCacheWrapper = None
  12. CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
  13. BatchPrefillWithPagedKVCacheWrapper = None
  14. FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
  15. import torch
  16. from aphrodite import _custom_ops as ops
  17. from aphrodite.attention.backends.abstract import (AttentionBackend,
  18. AttentionImpl,
  19. AttentionMetadata,
  20. AttentionMetadataBuilder,
  21. AttentionState,
  22. AttentionType)
  23. from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
  24. compute_slot_mapping,
  25. compute_slot_mapping_start_idx,
  26. is_block_tables_empty)
  27. from aphrodite.attention.ops.paged_attn import PagedAttention
  28. from aphrodite.common.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
  29. make_tensor_with_pad)
  30. if TYPE_CHECKING:
  31. from aphrodite.worker.model_runner import ModelInputForGPUBuilder
  32. class FlashInferBackend(AttentionBackend):
  33. @staticmethod
  34. def get_name() -> str:
  35. return "flashinfer"
  36. @staticmethod
  37. def get_impl_cls() -> Type["FlashInferImpl"]:
  38. return FlashInferImpl
  39. @staticmethod
  40. def get_metadata_cls() -> Type["AttentionMetadata"]:
  41. return FlashInferMetadata
  42. @staticmethod
  43. def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
  44. return FlashInferMetadataBuilder
  45. @staticmethod
  46. def get_state_cls() -> Type["FlashInferState"]:
  47. return FlashInferState
  48. @staticmethod
  49. def get_kv_cache_shape(
  50. num_blocks: int,
  51. block_size: int,
  52. num_kv_heads: int,
  53. head_size: int,
  54. ) -> Tuple[int, ...]:
  55. return (num_blocks, 2, block_size, num_kv_heads, head_size)
  56. @staticmethod
  57. def swap_blocks(
  58. src_kv_cache: torch.Tensor,
  59. dst_kv_cache: torch.Tensor,
  60. src_to_dst: torch.Tensor,
  61. ) -> None:
  62. PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
  63. @staticmethod
  64. def copy_blocks(
  65. kv_caches: List[torch.Tensor],
  66. src_to_dists: torch.Tensor,
  67. ) -> None:
  68. PagedAttention.copy_blocks(kv_caches, src_to_dists)
  69. @staticmethod
  70. def get_supported_head_sizes() -> List[int]:
  71. return [64, 128, 256]
  72. @staticmethod
  73. def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
  74. if kv_cache_dtype in ("fp8", "fp8_e4m3"):
  75. return torch.float8_e4m3fn
  76. elif kv_cache_dtype == "fp8_e5m2":
  77. return torch.float8_e5m2
  78. else:
  79. raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
  80. class FlashInferState(AttentionState):
  81. def __init__(self, runner):
  82. self.runner = runner
  83. self._is_graph_capturing = False
  84. self._workspace_buffer = None
  85. self._decode_wrapper = None
  86. self._prefill_wrapper = None
  87. def _get_workspace_buffer(self):
  88. if self._workspace_buffer is None:
  89. self._workspace_buffer = torch.empty(
  90. FLASHINFER_WORKSPACE_BUFFER_SIZE,
  91. dtype=torch.uint8,
  92. device=self.runner.device)
  93. return self._workspace_buffer
  94. def _get_prefill_wrapper(self):
  95. if self._prefill_wrapper is None:
  96. self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
  97. self._get_workspace_buffer(), "NHD")
  98. return self._prefill_wrapper
  99. def _get_decode_wrapper(self):
  100. if self._decode_wrapper is None:
  101. num_qo_heads = (self.runner.model_config.get_num_attention_heads(
  102. self.runner.parallel_config))
  103. num_kv_heads = self.runner.model_config.get_num_kv_heads(
  104. self.runner.parallel_config)
  105. use_tensor_cores = num_qo_heads // num_kv_heads >= 4
  106. self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
  107. self._get_workspace_buffer(),
  108. "NHD",
  109. use_tensor_cores=use_tensor_cores)
  110. return self._decode_wrapper
  111. @contextmanager
  112. def graph_capture(self, max_batch_size: int):
  113. self._is_graph_capturing = True
  114. self._graph_decode_wrapper = None
  115. self._graph_slot_mapping = torch.full((max_batch_size, ),
  116. PAD_SLOT_ID,
  117. dtype=torch.long,
  118. device=self.runner.device)
  119. self._graph_seq_lens = torch.ones(max_batch_size,
  120. dtype=torch.int32,
  121. device=self.runner.device)
  122. self._graph_block_tables = torch.from_numpy(
  123. self.runner.graph_block_tables).to(device=self.runner.device)
  124. self._graph_decode_workspace_buffer = self._get_workspace_buffer()
  125. self._graph_indices_buffer = torch.empty(
  126. max_batch_size * self.runner.cache_config.num_gpu_blocks,
  127. dtype=torch.int32,
  128. device=self.runner.device)
  129. self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
  130. dtype=torch.int32,
  131. device=self.runner.device)
  132. self._graph_last_page_len_buffer = torch.empty(
  133. max_batch_size, dtype=torch.int32, device=self.runner.device)
  134. yield
  135. self._is_graph_capturing = False
  136. del self._graph_slot_mapping
  137. del self._graph_seq_lens
  138. del self._graph_block_tables
  139. del self._graph_decode_workspace_buffer
  140. del self._graph_indices_buffer
  141. del self._graph_indptr_buffer
  142. del self._graph_last_page_len_buffer
  143. del self._graph_decode_wrapper
  144. def graph_clone(self, batch_size: int):
  145. assert self._is_graph_capturing
  146. state = self.__class__(self.runner)
  147. state._workspace_buffer = self._graph_decode_workspace_buffer
  148. state._decode_wrapper = self._graph_decode_wrapper
  149. state._prefill_wrapper = self._get_prefill_wrapper()
  150. return state
  151. def graph_capture_get_metadata_for_batch(self, batch_size: int):
  152. assert self._is_graph_capturing
  153. _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
  154. _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]
  155. num_qo_heads = (self.runner.model_config.get_num_attention_heads(
  156. self.runner.parallel_config))
  157. num_kv_heads = self.runner.model_config.get_num_kv_heads(
  158. self.runner.parallel_config)
  159. use_tensor_cores = num_qo_heads // num_kv_heads >= 4
  160. self._graph_decode_wrapper = \
  161. CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
  162. self._graph_decode_workspace_buffer, _indptr_buffer,
  163. self._graph_indices_buffer, _last_page_len_buffer, "NHD",
  164. use_tensor_cores)
  165. kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
  166. self.runner.kv_cache_dtype)
  167. paged_kv_indptr_tensor_host = torch.arange(0,
  168. batch_size + 1,
  169. dtype=torch.int32)
  170. paged_kv_indices_tensor_host = torch.arange(0,
  171. batch_size,
  172. dtype=torch.int32)
  173. paged_kv_last_page_len_tensor_host = torch.full((batch_size, ),
  174. self.runner.block_size,
  175. dtype=torch.int32)
  176. query_start_loc_host = torch.arange(0,
  177. batch_size + 1,
  178. dtype=torch.int32)
  179. attn_metadata = self.runner.attn_backend.make_metadata(
  180. num_prefills=0,
  181. slot_mapping=self._graph_slot_mapping[:batch_size],
  182. num_prefill_tokens=0,
  183. num_decode_tokens=batch_size,
  184. max_prefill_seq_len=0,
  185. block_tables=self._graph_block_tables,
  186. paged_kv_indptr=paged_kv_indptr_tensor_host,
  187. paged_kv_indices=paged_kv_indices_tensor_host,
  188. paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
  189. num_qo_heads=num_qo_heads,
  190. num_kv_heads=num_kv_heads,
  191. head_dim=self.runner.model_config.get_head_size(),
  192. page_size=self.runner.block_size,
  193. seq_start_loc=None,
  194. query_start_loc=query_start_loc_host,
  195. device=self.runner.device,
  196. data_type=kv_cache_dtype,
  197. use_cuda_graph=True,
  198. decode_wrapper=self._graph_decode_wrapper,
  199. prefill_wrapper=None)
  200. attn_metadata.begin_forward()
  201. return attn_metadata
  202. def get_graph_input_buffers(self, attn_metadata):
  203. return {
  204. "slot_mapping": attn_metadata.slot_mapping,
  205. }
  206. def prepare_graph_input_buffers(self, input_buffers, attn_metadata):
  207. return
  208. def begin_forward(self, model_input):
  209. assert not self._is_graph_capturing
  210. state = self
  211. if model_input.attn_metadata.use_cuda_graph:
  212. batch_size = model_input.input_tokens.shape[0]
  213. state = (self.runner.graph_runners[model_input.virtual_engine]
  214. [batch_size].attn_state)
  215. model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
  216. )
  217. model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
  218. model_input.attn_metadata.begin_forward()
  219. @dataclass
  220. class FlashInferMetadata(AttentionMetadata):
  221. # Maximum sequence length among prefill batch. 0 if there are decoding
  222. # requests only.
  223. max_prefill_seq_len: int
  224. use_cuda_graph: bool = True
  225. prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
  226. decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
  227. # Metadata for the prefill stage
  228. seq_start_loc: Optional[torch.Tensor] = None
  229. query_start_loc: Optional[torch.Tensor] = None
  230. block_tables: Optional[torch.Tensor] = None
  231. # An example for paged_kv_indices, paged_kv_indptr:
  232. # request 1, page indices [0, 5, 8]
  233. # request 2, page indices [1, 6, 7]
  234. # request 3, page indices [3, 4]
  235. # paged_kv_indices is a concatenation of page indices of all requests:
  236. # [0, 5, 8, 1, 6, 7, 3, 4]
  237. # paged_kv_indptr is used to index into paged_kv_indices:
  238. # [0, 3, 6, 8]
  239. # The indptr of the paged kv cache, shape: [batch_size + 1]
  240. paged_kv_indptr: Optional[torch.Tensor] = None
  241. # The page indices of the paged kv cache
  242. paged_kv_indices: Optional[torch.Tensor] = None
  243. # The number of entries in the last page of each request in
  244. # the paged kv cache, shape: [batch_size]
  245. paged_kv_last_page_len: Optional[torch.Tensor] = None
  246. # The number of query/output heads
  247. num_qo_heads: Optional[int] = None
  248. # The number of key/value heads
  249. num_kv_heads: Optional[int] = None
  250. # The dimension of the attention heads
  251. head_dim: Optional[int] = None
  252. # Block size of Aphrodite
  253. page_size: Optional[int] = None
  254. # The data type of the paged kv cache
  255. data_type: torch.dtype = None
  256. device: torch.device = torch.device("cuda")
  257. is_profile_run: bool = False
  258. def __post_init__(self):
  259. # Refer to
  260. # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
  261. supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
  262. if self.head_dim is not None and self.head_dim \
  263. not in supported_head_sizes:
  264. raise ValueError(
  265. f"Only {supported_head_sizes} are supported for head_dim,",
  266. f"received {self.head_dim}.")
  267. def begin_forward(self):
  268. if self.num_prefill_tokens > 0:
  269. if self.paged_kv_indices is None:
  270. return
  271. assert self.prefill_wrapper is not None
  272. assert self.query_start_loc is not None
  273. assert self.paged_kv_indices is not None
  274. assert self.paged_kv_indptr is not None
  275. assert self.paged_kv_last_page_len is not None
  276. batch_size = self.query_start_loc.shape[0] - 1
  277. assert batch_size >= 0
  278. # We will use flash attention for profiling to
  279. # determine the number of blocks. Therefore,
  280. # we don't need to prepare the input for flashinfer for profile run.
  281. if not self.is_profile_run:
  282. self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
  283. self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
  284. self.device)
  285. self.paged_kv_indices = self.paged_kv_indices.to(self.device)
  286. self.prefill_wrapper.end_forward()
  287. self.prefill_wrapper.begin_forward(
  288. self.query_start_loc, self.paged_kv_indptr,
  289. self.paged_kv_indices, self.paged_kv_last_page_len,
  290. self.num_qo_heads, self.num_kv_heads, self.head_dim,
  291. self.page_size)
  292. else:
  293. if not self.use_cuda_graph:
  294. assert self.paged_kv_indices is not None
  295. assert self.paged_kv_indptr is not None
  296. assert self.paged_kv_last_page_len is not None
  297. self.paged_kv_indices = self.paged_kv_indices.to(self.device)
  298. self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
  299. self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
  300. self.device)
  301. assert self.decode_wrapper is not None
  302. self.decode_wrapper.end_forward()
  303. self.decode_wrapper.begin_forward(
  304. self.paged_kv_indptr,
  305. self.paged_kv_indices,
  306. self.paged_kv_last_page_len,
  307. self.num_qo_heads,
  308. self.num_kv_heads,
  309. self.head_dim,
  310. self.page_size,
  311. # Disable flashinfer's pos encoding and use Aphrodite's rope.
  312. pos_encoding_mode="NONE",
  313. )
  314. def asdict_zerocopy(self,
  315. skip_fields: Optional[Set[str]] = None
  316. ) -> Dict[str, Any]:
  317. if skip_fields is None:
  318. skip_fields = set()
  319. # We need to skip the prefill/decode_wrapper field since it cannot be
  320. # broadcasted with nccl when TP is enabled.
  321. skip_fields.add('prefill_wrapper')
  322. skip_fields.add('decode_wrapper')
  323. return super().asdict_zerocopy(skip_fields)
  324. @property
  325. def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
  326. # Currently chunked prefill is not supported
  327. if self.num_decode_tokens == 0:
  328. assert self.num_prefills > 0
  329. return self
  330. return None
  331. @property
  332. def decode_metadata(self) -> Optional["FlashInferMetadata"]:
  333. # Currently chunked prefill is not supported
  334. if self.num_prefills > 0:
  335. assert self.num_decode_tokens == 0, (
  336. "Chunked prefill is not supported with flashinfer yet.")
  337. return None
  338. return self
  339. class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
  340. def __init__(self, input_builder: "ModelInputForGPUBuilder"):
  341. self.slot_mapping: List[int] = []
  342. self.prefill_seq_lens: List[int] = []
  343. self.context_lens: List[int] = []
  344. self.block_tables: List[List[int]] = []
  345. self.curr_seq_lens: List[int] = []
  346. self.num_prefills = 0
  347. self.num_prefill_tokens = 0
  348. self.num_decode_tokens = 0
  349. self.input_builder = input_builder
  350. self.runner = input_builder.runner
  351. self.sliding_window = input_builder.sliding_window
  352. self.block_size = input_builder.block_size
  353. self.use_v2_block_manager = (
  354. input_builder.scheduler_config.use_v2_block_manager)
  355. # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
  356. # for the precise definition of the following fields.
  357. # An example:
  358. # request 1, page indices [0, 5, 8]
  359. # request 2, page indices [1, 6, 7]
  360. # request 3, page indices [3, 4]
  361. # paged_kv_indices is a concatenation of page indices of all requests:
  362. # [0, 5, 8, 1, 6, 7, 3, 4]
  363. # paged_kv_indptr is used to index into paged_kv_indices:
  364. # [0, 3, 6, 8]
  365. self.paged_kv_indices: List[int] = []
  366. # 0 at the beginning of paged_kv_indptr indicates the start of the
  367. # first request’s page indices in the paged_kv_indices list.
  368. self.paged_kv_indptr: List[int] = [0]
  369. # paged_kv_last_page_len is the length of the last page of each request
  370. self.paged_kv_last_page_len: List[int] = []
  371. self.is_profile_run: bool = False
  372. def _add_seq_group(
  373. self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
  374. chunked_prefill_enabled: bool):
  375. """Add a sequence group to the metadata. Specifically update/append
  376. 1. context length.
  377. 2. block table.
  378. 3. slot mapping.
  379. """
  380. is_prompt = inter_data.is_prompt
  381. block_tables = inter_data.block_tables
  382. computed_block_nums = inter_data.computed_block_nums
  383. for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
  384. curr_sliding_window_block) in zip(
  385. inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
  386. inter_data.orig_seq_lens, inter_data.seq_lens,
  387. inter_data.query_lens, inter_data.context_lens,
  388. inter_data.curr_sliding_window_blocks):
  389. self.context_lens.append(context_len)
  390. if is_prompt:
  391. self.num_prefills += 1
  392. self.num_prefill_tokens += token_len
  393. self.prefill_seq_lens.append(seq_len)
  394. else:
  395. assert query_len == 1, (
  396. "seq_len: {}, context_len: {}, query_len: {}".format(
  397. seq_len, context_len, query_len))
  398. self.num_decode_tokens += query_len
  399. self.curr_seq_lens.append(curr_seq_len)
  400. # Compute block table.
  401. # TODO: Combine chunked prefill and prefix caching by
  402. # only allowing multiple of block_size chunk size.
  403. # NOTE: This only works for oooooooxxx style attention.
  404. block_table = []
  405. if inter_data.prefix_cache_hit:
  406. block_table = computed_block_nums
  407. elif ((chunked_prefill_enabled or not is_prompt)
  408. and block_tables is not None):
  409. block_table = block_tables[seq_id][-curr_sliding_window_block:]
  410. self.block_tables.append(block_table)
  411. is_profile_run = is_block_tables_empty(block_tables)
  412. # Compute slot mapping.
  413. start_idx = compute_slot_mapping_start_idx(
  414. is_prompt, query_len, context_len, self.sliding_window,
  415. self.use_v2_block_manager)
  416. compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
  417. seq_len, context_len, start_idx,
  418. self.block_size, inter_data.block_tables)
  419. # It is not necessary to add paged_kv_indices, paged_kv_indptr,
  420. # and paged_kv_last_page_len for profile run because we will
  421. # create dummy inputs.
  422. if is_profile_run:
  423. self.is_profile_run = is_profile_run
  424. return
  425. block_table = block_tables[seq_id]
  426. self._update_paged_kv_tensors(block_table, seq_len)
  427. def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
  428. # Get the number of valid blocks based on sequence length.
  429. # If seq_len = 16, block_size = 16,
  430. # block_table_bound is 1 with 1 valid block.
  431. # If seq_len = 15, block_size = 16,
  432. # block_table_bound is 0 + 1 with 1 valid block.
  433. block_table_bound = seq_len // self.block_size + 1 \
  434. if seq_len % self.block_size != 0 \
  435. else seq_len // self.block_size
  436. self.paged_kv_indices.extend(block_table[:block_table_bound])
  437. self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
  438. block_table_bound)
  439. last_page_len = seq_len % self.block_size
  440. if last_page_len == 0:
  441. last_page_len = self.block_size
  442. self.paged_kv_last_page_len.append(last_page_len)
  443. def build(self, seq_lens: List[int], query_lens: List[int],
  444. cuda_graph_pad_size: int, batch_size: int):
  445. """Build attention metadata with on-device tensors.
  446. Args:
  447. seq_lens: The maybe padded sequence lengths of the input sequences.
  448. query_lens: The query lengths of the input sequences.
  449. cuda_graph_pad_size: The padding size for cuda graph.
  450. -1 if cuda graph is not used.
  451. batch_size: The maybe padded batch size.
  452. """
  453. for inter_data in self.input_builder.inter_data_list:
  454. self._add_seq_group(inter_data,
  455. self.input_builder.chunked_prefill_enabled)
  456. device = self.runner.device
  457. use_captured_graph = cuda_graph_pad_size != -1
  458. max_query_len = max(query_lens)
  459. max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
  460. num_decode_tokens = self.num_decode_tokens
  461. if use_captured_graph:
  462. self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
  463. self.block_tables.extend([] * cuda_graph_pad_size)
  464. num_decode_tokens = batch_size
  465. # The shape of graph_block_tables is
  466. # [max batch size, max context len // block size].
  467. input_block_tables = self.runner.graph_block_tables[:batch_size]
  468. for i, block_table in enumerate(self.block_tables):
  469. if block_table:
  470. input_block_tables[i, :len(block_table)] = block_table
  471. block_tables = torch.from_numpy(input_block_tables).to(
  472. device, non_blocking=True)
  473. last_paged_kv_indptr = self.paged_kv_indptr[-1]
  474. self.paged_kv_indptr.extend([last_paged_kv_indptr] *
  475. cuda_graph_pad_size)
  476. self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
  477. else:
  478. block_tables = make_tensor_with_pad(
  479. self.block_tables,
  480. pad=0,
  481. dtype=torch.int,
  482. device=device,
  483. )
  484. assert max_query_len > 0, ("query_lens: {}".format(query_lens))
  485. assert device is not None
  486. seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
  487. self.runner.pin_memory)
  488. query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
  489. self.runner.pin_memory)
  490. slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
  491. device, self.runner.pin_memory)
  492. query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
  493. dtype=torch.int32,
  494. device=device)
  495. seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
  496. dtype=torch.int32,
  497. device=device)
  498. torch.cumsum(seq_lens_tensor,
  499. dim=0,
  500. dtype=seq_start_loc.dtype,
  501. out=seq_start_loc[1:])
  502. torch.cumsum(query_lens_tensor,
  503. dim=0,
  504. dtype=query_start_loc.dtype,
  505. out=query_start_loc[1:])
  506. if len(self.paged_kv_indptr) > 0:
  507. paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
  508. device="cpu",
  509. dtype=torch.int)
  510. paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
  511. device="cpu",
  512. dtype=torch.int)
  513. paged_kv_last_page_len_tensor = torch.tensor(
  514. self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
  515. else:
  516. paged_kv_indices_tensor = None
  517. paged_kv_indptr_tensor = None
  518. paged_kv_last_page_len_tensor = None
  519. kv_cache_dtype = get_kv_cache_torch_dtype(
  520. self.runner.kv_cache_dtype, self.runner.model_config.dtype)
  521. return FlashInferMetadata(
  522. num_prefills=self.num_prefills,
  523. slot_mapping=slot_mapping_tensor,
  524. num_prefill_tokens=self.num_prefill_tokens,
  525. num_decode_tokens=num_decode_tokens,
  526. max_prefill_seq_len=max_prefill_seq_len,
  527. block_tables=block_tables,
  528. paged_kv_indptr=paged_kv_indptr_tensor,
  529. paged_kv_indices=paged_kv_indices_tensor,
  530. paged_kv_last_page_len=paged_kv_last_page_len_tensor,
  531. num_qo_heads=self.runner.model_config.get_num_attention_heads(
  532. self.runner.parallel_config),
  533. num_kv_heads=self.runner.model_config.get_num_kv_heads(
  534. self.runner.parallel_config),
  535. head_dim=self.runner.model_config.get_head_size(),
  536. page_size=self.block_size,
  537. seq_start_loc=seq_start_loc,
  538. query_start_loc=query_start_loc,
  539. device=device,
  540. data_type=kv_cache_dtype,
  541. use_cuda_graph=use_captured_graph,
  542. is_profile_run=self.is_profile_run)
  543. class FlashInferImpl(AttentionImpl):
  544. def __init__(
  545. self,
  546. num_heads: int,
  547. head_size: int,
  548. scale: float,
  549. num_kv_heads: int,
  550. alibi_slopes: Optional[List[float]],
  551. sliding_window: Optional[int],
  552. kv_cache_dtype: str,
  553. blocksparse_params: Optional[Dict[str, Any]] = None,
  554. logits_soft_cap: Optional[float] = None,
  555. ) -> None:
  556. assert blocksparse_params is None, ValueError(
  557. "FlashInfer does not support block-sparse attention.")
  558. self.num_heads = num_heads
  559. self.head_size = head_size
  560. self.scale = float(scale)
  561. self.num_kv_heads = num_kv_heads
  562. if alibi_slopes is not None:
  563. alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
  564. self.alibi_slopes = alibi_slopes
  565. if sliding_window is not None:
  566. raise ValueError("Sliding window is not supported in FlashInfer.")
  567. self.sliding_window = (-1, -1)
  568. self.kv_cache_dtype = kv_cache_dtype
  569. self.logits_soft_cap = logits_soft_cap
  570. assert self.num_heads % self.num_kv_heads == 0
  571. self.num_queries_per_kv = self.num_heads // self.num_kv_heads
  572. def forward(
  573. self,
  574. query: torch.Tensor,
  575. key: torch.Tensor,
  576. value: torch.Tensor,
  577. kv_cache: Optional[torch.Tensor],
  578. attn_metadata: FlashInferMetadata,
  579. k_scale: float = 1.0,
  580. v_scale: float = 1.0,
  581. attn_type: AttentionType = AttentionType.DECODER,
  582. ) -> torch.Tensor:
  583. assert k_scale == 1.0 and v_scale == 1.0, (
  584. "key/v_scale is not supported in FlashInfer.")
  585. if attn_type != AttentionType.DECODER:
  586. raise NotImplementedError("Encoder self-attention and "
  587. "encoder/decoder cross-attention "
  588. "are not implemented for "
  589. "FlashInferImpl")
  590. num_tokens, hidden_size = query.shape
  591. query = query.view(-1, self.num_heads, self.head_size)
  592. key = key.view(-1, self.num_kv_heads, self.head_size)
  593. value = value.view(-1, self.num_kv_heads, self.head_size)
  594. if attn_metadata.num_prefill_tokens > 0:
  595. assert attn_metadata.num_decode_tokens == 0, (
  596. "Chunked prefill is not supported with flashinfer yet.")
  597. if attn_metadata.num_decode_tokens > 0:
  598. assert attn_metadata.num_prefill_tokens == 0, (
  599. "Chunked prefill is not supported with flashinfer yet.")
  600. if kv_cache is not None:
  601. # Use the same reshape and cache kernel as flash attention.
  602. ops.reshape_and_cache_flash(
  603. key,
  604. value,
  605. kv_cache[:, 0],
  606. kv_cache[:, 1],
  607. attn_metadata.slot_mapping.flatten(),
  608. self.kv_cache_dtype,
  609. k_scale,
  610. v_scale,
  611. )
  612. # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
  613. # to process the cache in fp8
  614. torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
  615. self.kv_cache_dtype)
  616. kv_cache = kv_cache.view(torch_dtype)
  617. query = query.contiguous(
  618. ) # Flashinfer requires query to be contiguous
  619. if prefill_meta := attn_metadata.prefill_metadata:
  620. # We will use flash attention for prefill
  621. # when kv_cache is not provided.
  622. # This happens when Aphrodite runs the profiling to
  623. # determine the number of blocks.
  624. if kv_cache is None:
  625. output = torch.ops.aphrodite.flash_attn_varlen_func(
  626. q=query,
  627. k=key,
  628. v=value,
  629. cu_seqlens_q=prefill_meta.seq_start_loc,
  630. cu_seqlens_k=prefill_meta.seq_start_loc,
  631. max_seqlen_q=prefill_meta.max_prefill_seq_len,
  632. max_seqlen_k=prefill_meta.max_prefill_seq_len,
  633. softmax_scale=self.scale,
  634. causal=True,
  635. window_size=self.sliding_window,
  636. alibi_slopes=self.alibi_slopes,
  637. )
  638. else:
  639. assert prefill_meta is not None
  640. assert prefill_meta.prefill_wrapper is not None
  641. output = prefill_meta.prefill_wrapper.forward(
  642. query,
  643. kv_cache,
  644. logits_soft_cap=self.logits_soft_cap,
  645. causal=True)
  646. else:
  647. assert attn_metadata.decode_metadata is not None
  648. assert attn_metadata.decode_metadata.decode_wrapper is not None
  649. output = attn_metadata.decode_metadata.decode_wrapper.forward(
  650. query,
  651. kv_cache,
  652. sm_scale=self.scale,
  653. logits_soft_cap=self.logits_soft_cap,
  654. k_scale=k_scale,
  655. v_scale=v_scale)
  656. return output.view(num_tokens, hidden_size)