1
0

flashinfer.py 35 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818
  1. from contextlib import contextmanager
  2. from dataclasses import dataclass
  3. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
  4. try:
  5. from flashinfer import BatchDecodeWithPagedKVCacheWrapper
  6. from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
  7. from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
  8. import aphrodite.attention.backends.flash_attn # noqa
  9. FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
  10. except ImportError:
  11. BatchDecodeWithPagedKVCacheWrapper = None
  12. CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
  13. BatchPrefillWithPagedKVCacheWrapper = None
  14. FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
  15. import torch
  16. from aphrodite import _custom_ops as ops
  17. from aphrodite.attention.backends.abstract import (AttentionBackend,
  18. AttentionImpl,
  19. AttentionMetadata,
  20. AttentionMetadataBuilder,
  21. AttentionState,
  22. AttentionType)
  23. from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
  24. compute_slot_mapping,
  25. compute_slot_mapping_start_idx,
  26. is_block_tables_empty)
  27. from aphrodite.attention.ops.paged_attn import PagedAttention
  28. from aphrodite.common.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
  29. make_tensor_with_pad)
  30. if TYPE_CHECKING:
  31. from aphrodite.worker.model_runner import (
  32. ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata)
  33. class FlashInferBackend(AttentionBackend):
  34. @staticmethod
  35. def get_name() -> str:
  36. return "flashinfer"
  37. @staticmethod
  38. def get_impl_cls() -> Type["FlashInferImpl"]:
  39. return FlashInferImpl
  40. @staticmethod
  41. def get_metadata_cls() -> Type["AttentionMetadata"]:
  42. return FlashInferMetadata
  43. @staticmethod
  44. def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
  45. return FlashInferMetadataBuilder
  46. @staticmethod
  47. def get_state_cls() -> Type["FlashInferState"]:
  48. return FlashInferState
  49. @staticmethod
  50. def get_kv_cache_shape(
  51. num_blocks: int,
  52. block_size: int,
  53. num_kv_heads: int,
  54. head_size: int,
  55. ) -> Tuple[int, ...]:
  56. return (num_blocks, 2, block_size, num_kv_heads, head_size)
  57. @staticmethod
  58. def swap_blocks(
  59. src_kv_cache: torch.Tensor,
  60. dst_kv_cache: torch.Tensor,
  61. src_to_dst: torch.Tensor,
  62. ) -> None:
  63. PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
  64. @staticmethod
  65. def copy_blocks(
  66. kv_caches: List[torch.Tensor],
  67. src_to_dists: torch.Tensor,
  68. ) -> None:
  69. PagedAttention.copy_blocks(kv_caches, src_to_dists)
  70. @staticmethod
  71. def get_supported_head_sizes() -> List[int]:
  72. return [64, 128, 256]
  73. @staticmethod
  74. def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
  75. if kv_cache_dtype in ("fp8", "fp8_e4m3"):
  76. return torch.float8_e4m3fn
  77. elif kv_cache_dtype == "fp8_e5m2":
  78. return torch.float8_e5m2
  79. else:
  80. raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
  81. class FlashInferState(AttentionState):
  82. def __init__(self, runner):
  83. self.runner = runner
  84. self._is_graph_capturing = False
  85. self._workspace_buffer = None
  86. self._decode_wrapper = None
  87. self._prefill_wrapper = None
  88. def _get_workspace_buffer(self):
  89. if self._workspace_buffer is None:
  90. self._workspace_buffer = torch.empty(
  91. FLASHINFER_WORKSPACE_BUFFER_SIZE,
  92. dtype=torch.uint8,
  93. device=self.runner.device)
  94. return self._workspace_buffer
  95. def _get_prefill_wrapper(self):
  96. if self._prefill_wrapper is None:
  97. self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
  98. self._get_workspace_buffer(), "NHD")
  99. return self._prefill_wrapper
  100. def _get_decode_wrapper(self):
  101. if self._decode_wrapper is None:
  102. num_qo_heads = (self.runner.model_config.get_num_attention_heads(
  103. self.runner.parallel_config))
  104. num_kv_heads = self.runner.model_config.get_num_kv_heads(
  105. self.runner.parallel_config)
  106. use_tensor_cores = num_qo_heads // num_kv_heads > 4
  107. self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
  108. self._get_workspace_buffer(),
  109. "NHD",
  110. use_tensor_cores=use_tensor_cores)
  111. return self._decode_wrapper
  112. @contextmanager
  113. def graph_capture(self, max_batch_size: int):
  114. self._is_graph_capturing = True
  115. self._graph_decode_wrapper = None
  116. self._graph_slot_mapping = torch.full((max_batch_size, ),
  117. PAD_SLOT_ID,
  118. dtype=torch.long,
  119. device=self.runner.device)
  120. self._graph_seq_lens = torch.ones(max_batch_size,
  121. dtype=torch.int32,
  122. device=self.runner.device)
  123. self._graph_block_tables = torch.from_numpy(
  124. self.runner.graph_block_tables).to(device=self.runner.device)
  125. self._graph_decode_workspace_buffer = self._get_workspace_buffer()
  126. self._graph_indices_buffer = torch.empty(
  127. max_batch_size * self.runner.cache_config.num_gpu_blocks,
  128. dtype=torch.int32,
  129. device=self.runner.device)
  130. self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
  131. dtype=torch.int32,
  132. device=self.runner.device)
  133. self._graph_last_page_len_buffer = torch.empty(
  134. max_batch_size, dtype=torch.int32, device=self.runner.device)
  135. yield
  136. self._is_graph_capturing = False
  137. del self._graph_slot_mapping
  138. del self._graph_seq_lens
  139. del self._graph_block_tables
  140. del self._graph_decode_workspace_buffer
  141. del self._graph_indices_buffer
  142. del self._graph_indptr_buffer
  143. del self._graph_last_page_len_buffer
  144. del self._graph_decode_wrapper
  145. def graph_clone(self, batch_size: int):
  146. assert self._is_graph_capturing
  147. state = self.__class__(self.runner)
  148. state._workspace_buffer = self._graph_decode_workspace_buffer
  149. state._decode_wrapper = self._graph_decode_wrapper
  150. state._prefill_wrapper = self._get_prefill_wrapper()
  151. return state
  152. def graph_capture_get_metadata_for_batch(self, batch_size: int):
  153. assert self._is_graph_capturing
  154. _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
  155. _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]
  156. num_qo_heads = (self.runner.model_config.get_num_attention_heads(
  157. self.runner.parallel_config))
  158. num_kv_heads = self.runner.model_config.get_num_kv_heads(
  159. self.runner.parallel_config)
  160. use_tensor_cores = num_qo_heads // num_kv_heads > 4
  161. self._graph_decode_wrapper = \
  162. CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
  163. self._graph_decode_workspace_buffer, _indptr_buffer,
  164. self._graph_indices_buffer, _last_page_len_buffer, "NHD",
  165. use_tensor_cores)
  166. if self.runner.kv_cache_dtype.startswith("fp8"):
  167. kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
  168. self.runner.kv_cache_dtype)
  169. else:
  170. kv_cache_dtype = get_kv_cache_torch_dtype(
  171. self.runner.kv_cache_dtype, self.runner.model_config.dtype)
  172. paged_kv_indptr_tensor_host = torch.arange(0,
  173. batch_size + 1,
  174. dtype=torch.int32)
  175. paged_kv_indices_tensor_host = torch.arange(0,
  176. batch_size,
  177. dtype=torch.int32)
  178. paged_kv_last_page_len_tensor_host = torch.full((batch_size, ),
  179. self.runner.block_size,
  180. dtype=torch.int32)
  181. query_start_loc_host = torch.arange(0,
  182. batch_size + 1,
  183. dtype=torch.int32)
  184. attn_metadata = self.runner.attn_backend.make_metadata(
  185. num_prefills=0,
  186. slot_mapping=self._graph_slot_mapping[:batch_size],
  187. num_prefill_tokens=0,
  188. num_decode_tokens=batch_size,
  189. max_prefill_seq_len=0,
  190. block_tables=self._graph_block_tables,
  191. paged_kv_indptr=paged_kv_indptr_tensor_host,
  192. paged_kv_indices=paged_kv_indices_tensor_host,
  193. paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
  194. num_qo_heads=num_qo_heads,
  195. num_kv_heads=num_kv_heads,
  196. head_dim=self.runner.model_config.get_head_size(),
  197. page_size=self.runner.block_size,
  198. seq_start_loc=None,
  199. query_start_loc=query_start_loc_host,
  200. device=self.runner.device,
  201. data_type=kv_cache_dtype,
  202. q_data_type=self.runner.model_config.dtype,
  203. use_cuda_graph=True,
  204. decode_wrapper=self._graph_decode_wrapper,
  205. prefill_wrapper=None)
  206. attn_metadata.begin_forward()
  207. return attn_metadata
  208. def get_graph_input_buffers(self, attn_metadata):
  209. return {
  210. "slot_mapping": attn_metadata.slot_mapping,
  211. }
  212. def prepare_graph_input_buffers(self, input_buffers, attn_metadata):
  213. return
  214. def begin_forward(self, model_input):
  215. assert not self._is_graph_capturing
  216. state = self
  217. if model_input.attn_metadata.use_cuda_graph:
  218. batch_size = model_input.input_tokens.shape[0]
  219. state = (self.runner.graph_runners[model_input.virtual_engine]
  220. [batch_size].attn_state)
  221. model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
  222. )
  223. model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
  224. model_input.attn_metadata.begin_forward()
  225. @dataclass
  226. class FlashInferMetadata(AttentionMetadata):
  227. # Maximum sequence length among prefill batch. 0 if there are decoding
  228. # requests only.
  229. max_prefill_seq_len: int
  230. use_cuda_graph: bool = True
  231. prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
  232. decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
  233. # Metadata for the prefill stage
  234. seq_start_loc: Optional[torch.Tensor] = None
  235. query_start_loc: Optional[torch.Tensor] = None
  236. block_tables: Optional[torch.Tensor] = None
  237. # used for GPU in-place advance_step
  238. seq_lens_tensor: Optional[torch.Tensor] = None
  239. block_table_bound: Optional[torch.Tensor] = None
  240. # An example for paged_kv_indices, paged_kv_indptr:
  241. # request 1, page indices [0, 5, 8]
  242. # request 2, page indices [1, 6, 7]
  243. # request 3, page indices [3, 4]
  244. # paged_kv_indices is a concatenation of page indices of all requests:
  245. # [0, 5, 8, 1, 6, 7, 3, 4]
  246. # paged_kv_indptr is used to index into paged_kv_indices:
  247. # [0, 3, 6, 8]
  248. # The indptr of the paged kv cache, shape: [batch_size + 1]
  249. paged_kv_indptr: Optional[torch.Tensor] = None
  250. # The page indices of the paged kv cache
  251. paged_kv_indices: Optional[torch.Tensor] = None
  252. # The number of entries in the last page of each request in
  253. # the paged kv cache, shape: [batch_size]
  254. paged_kv_last_page_len: Optional[torch.Tensor] = None
  255. # The number of query/output heads
  256. num_qo_heads: Optional[int] = None
  257. # The number of key/value heads
  258. num_kv_heads: Optional[int] = None
  259. # The dimension of the attention heads
  260. head_dim: Optional[int] = None
  261. # Block size of aphrodite
  262. page_size: Optional[int] = None
  263. # The data type of the paged kv cache
  264. data_type: torch.dtype = None
  265. # The data type of the query
  266. q_data_type: torch.dtype = None
  267. device: torch.device = torch.device("cuda")
  268. is_profile_run: bool = False
  269. def __post_init__(self):
  270. # Refer to
  271. # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
  272. supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
  273. if self.head_dim is not None and self.head_dim \
  274. not in supported_head_sizes:
  275. raise ValueError(
  276. f"Only {supported_head_sizes} are supported for head_dim,",
  277. f"received {self.head_dim}.")
  278. def begin_forward(self):
  279. if self.num_prefill_tokens > 0:
  280. if self.paged_kv_indices is None:
  281. return
  282. assert self.prefill_wrapper is not None
  283. assert self.query_start_loc is not None
  284. assert self.paged_kv_indices is not None
  285. assert self.paged_kv_indptr is not None
  286. assert self.paged_kv_last_page_len is not None
  287. assert self.block_table_bound is not None
  288. assert self.seq_lens_tensor is not None
  289. batch_size = self.query_start_loc.shape[0] - 1
  290. assert batch_size >= 0
  291. # We will use flash attention for profiling to
  292. # determine the number of blocks. Therefore,
  293. # we don't need to prepare the input for flashinfer for profile run.
  294. if not self.is_profile_run:
  295. self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
  296. self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
  297. self.device)
  298. self.block_table_bound = self.block_table_bound.to(self.device)
  299. self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
  300. self.paged_kv_indices = self.paged_kv_indices.to(self.device)
  301. self.prefill_wrapper.end_forward()
  302. self.prefill_wrapper.begin_forward(
  303. self.query_start_loc, self.paged_kv_indptr,
  304. self.paged_kv_indices, self.paged_kv_last_page_len,
  305. self.num_qo_heads, self.num_kv_heads, self.head_dim,
  306. self.page_size)
  307. else:
  308. assert self.paged_kv_indices is not None
  309. assert self.paged_kv_indptr is not None
  310. assert self.paged_kv_last_page_len is not None
  311. self.paged_kv_indices = self.paged_kv_indices.to(self.device)
  312. self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
  313. self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
  314. self.device)
  315. # handle model warmup path
  316. if self.block_table_bound is not None:
  317. self.block_table_bound = self.block_table_bound.to(self.device)
  318. if self.seq_lens_tensor is not None:
  319. self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
  320. assert self.decode_wrapper is not None
  321. self.decode_wrapper.end_forward()
  322. self.decode_wrapper.begin_forward(
  323. self.paged_kv_indptr,
  324. self.paged_kv_indices,
  325. self.paged_kv_last_page_len,
  326. self.num_qo_heads,
  327. self.num_kv_heads,
  328. self.head_dim,
  329. self.page_size,
  330. # Disable flashinfer's pos encoding and use aphrodite's rope.
  331. pos_encoding_mode="NONE",
  332. # kv-cache data type.
  333. data_type=self.data_type,
  334. # query data type.
  335. q_data_type=self.q_data_type)
  336. def asdict_zerocopy(self,
  337. skip_fields: Optional[Set[str]] = None
  338. ) -> Dict[str, Any]:
  339. if skip_fields is None:
  340. skip_fields = set()
  341. # We need to skip the prefill/decode_wrapper field since it cannot be
  342. # broadcasted with nccl when TP is enabled.
  343. skip_fields.add('prefill_wrapper')
  344. skip_fields.add('decode_wrapper')
  345. return super().asdict_zerocopy(skip_fields)
  346. @property
  347. def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
  348. # Currently chunked prefill is not supported
  349. if self.num_decode_tokens == 0:
  350. assert self.num_prefills > 0
  351. return self
  352. return None
  353. @property
  354. def decode_metadata(self) -> Optional["FlashInferMetadata"]:
  355. # Currently chunked prefill is not supported
  356. if self.num_prefills > 0:
  357. assert self.num_decode_tokens == 0, (
  358. "Chunked prefill is not supported with flashinfer yet.")
  359. return None
  360. return self
  361. def advance_step(
  362. self,
  363. model_input: "ModelInputForGPUWithSamplingMetadata",
  364. sampled_token_ids: Optional[torch.Tensor],
  365. block_size: int,
  366. num_seqs: int,
  367. num_queries: int,
  368. ):
  369. """
  370. Update metadata in-place to advance one decode step.
  371. """
  372. assert num_seqs > 0
  373. assert num_queries > 0
  374. assert model_input.attn_metadata is not None
  375. assert sampled_token_ids is not None
  376. # When using cudagraph, the num_seqs is padded to the next captured
  377. # batch sized, but num_queries tracks the actual number of requests in
  378. # the batch. For --enforce-eager mode, num_seqs == num_queries
  379. if num_seqs != num_queries:
  380. assert num_seqs > num_queries
  381. assert self.use_cuda_graph
  382. model_input.input_tokens[:num_queries] = sampled_token_ids.flatten()
  383. # Update GPU tensors
  384. ops.advance_step_flashinfer(
  385. num_seqs=num_seqs,
  386. num_queries=num_queries,
  387. block_size=block_size,
  388. input_tokens=model_input.input_tokens,
  389. sampled_token_ids=model_input.input_tokens,
  390. input_positions=model_input.input_positions,
  391. seq_lens=self.seq_lens_tensor,
  392. slot_mapping=self.slot_mapping,
  393. block_tables=self.block_tables,
  394. paged_kv_indices=self.paged_kv_indices,
  395. paged_kv_indptr=self.paged_kv_indptr,
  396. paged_kv_last_page_len=self.paged_kv_last_page_len,
  397. block_table_bound=self.block_table_bound)
  398. class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
  399. def __init__(self, input_builder: "ModelInputForGPUBuilder"):
  400. self.slot_mapping: List[int] = []
  401. self.prefill_seq_lens: List[int] = []
  402. self.context_lens: List[int] = []
  403. self.block_tables: List[List[int]] = []
  404. self.curr_seq_lens: List[int] = []
  405. self.num_prefills = 0
  406. self.num_prefill_tokens = 0
  407. self.num_decode_tokens = 0
  408. self.input_builder = input_builder
  409. self.runner = input_builder.runner
  410. self.sliding_window = input_builder.sliding_window
  411. self.block_size = input_builder.block_size
  412. self.use_v2_block_manager = (
  413. input_builder.scheduler_config.use_v2_block_manager)
  414. # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
  415. # for the precise definition of the following fields.
  416. # An example:
  417. # request 1, page indices [0, 5, 8]
  418. # request 2, page indices [1, 6, 7]
  419. # request 3, page indices [3, 4]
  420. # paged_kv_indices is a concatenation of page indices of all requests:
  421. # [0, 5, 8, 1, 6, 7, 3, 4]
  422. # paged_kv_indptr is used to index into paged_kv_indices:
  423. # [0, 3, 6, 8]
  424. self.paged_kv_indices: List[int] = []
  425. # 0 at the beginning of paged_kv_indptr indicates the start of the
  426. # first request’s page indices in the paged_kv_indices list.
  427. self.paged_kv_indptr: List[int] = [0]
  428. # paged_kv_last_page_len is the length of the last page of each request
  429. self.paged_kv_last_page_len: List[int] = []
  430. self.total_blocks = 0
  431. self.is_profile_run: bool = False
  432. def _add_seq_group(
  433. self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
  434. chunked_prefill_enabled: bool):
  435. """Add a sequence group to the metadata. Specifically update/append
  436. 1. context length.
  437. 2. block table.
  438. 3. slot mapping.
  439. """
  440. is_prompt = inter_data.is_prompt
  441. block_tables = inter_data.block_tables
  442. computed_block_nums = inter_data.computed_block_nums
  443. for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
  444. curr_sliding_window_block) in zip(
  445. inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
  446. inter_data.orig_seq_lens, inter_data.seq_lens,
  447. inter_data.query_lens, inter_data.context_lens,
  448. inter_data.curr_sliding_window_blocks):
  449. self.context_lens.append(context_len)
  450. if is_prompt:
  451. self.num_prefills += 1
  452. self.num_prefill_tokens += token_len
  453. self.prefill_seq_lens.append(seq_len)
  454. else:
  455. assert query_len == 1, (
  456. "seq_len: {}, context_len: {}, query_len: {}".format(
  457. seq_len, context_len, query_len))
  458. self.num_decode_tokens += query_len
  459. self.curr_seq_lens.append(curr_seq_len)
  460. # Compute block table.
  461. # TODO(sang): Combine chunked prefill and prefix caching by
  462. # only allowing multiple of block_size chunk size.
  463. # NOTE: This only works for oooooooxxx style attention.
  464. block_table = []
  465. if inter_data.prefix_cache_hit:
  466. block_table = computed_block_nums
  467. elif ((chunked_prefill_enabled or not is_prompt)
  468. and block_tables is not None):
  469. block_table = block_tables[seq_id][-curr_sliding_window_block:]
  470. self.block_tables.append(block_table)
  471. is_profile_run = is_block_tables_empty(block_tables)
  472. # Compute slot mapping.
  473. start_idx = compute_slot_mapping_start_idx(
  474. is_prompt, query_len, context_len, self.sliding_window,
  475. self.use_v2_block_manager)
  476. compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
  477. seq_len, context_len, start_idx,
  478. self.block_size, inter_data.block_tables)
  479. # It is not necessary to add paged_kv_indices, paged_kv_indptr,
  480. # and paged_kv_last_page_len for profile run because we will
  481. # create dummy inputs.
  482. if is_profile_run:
  483. self.is_profile_run = is_profile_run
  484. return
  485. block_table = block_tables[seq_id]
  486. self._update_paged_kv_tensors(block_table, seq_len)
  487. def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
  488. # Get the number of valid blocks based on sequence length.
  489. # If seq_len = 16, block_size = 16,
  490. # block_table_bound is 1 with 1 valid block.
  491. # If seq_len = 15, block_size = 16,
  492. # block_table_bound is 0 + 1 with 1 valid block.
  493. self.total_blocks += len(block_table)
  494. block_table_bound = seq_len // self.block_size + 1 \
  495. if seq_len % self.block_size != 0 \
  496. else seq_len // self.block_size
  497. self.paged_kv_indices.extend(block_table[:block_table_bound])
  498. self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
  499. block_table_bound)
  500. last_page_len = seq_len % self.block_size
  501. if last_page_len == 0:
  502. last_page_len = self.block_size
  503. self.paged_kv_last_page_len.append(last_page_len)
  504. def build(self, seq_lens: List[int], query_lens: List[int],
  505. cuda_graph_pad_size: int, batch_size: int):
  506. """Build attention metadata with on-device tensors.
  507. Args:
  508. seq_lens: The maybe padded sequence lengths of the input sequences.
  509. query_lens: The query lengths of the input sequences.
  510. cuda_graph_pad_size: The padding size for cuda graph.
  511. -1 if cuda graph is not used.
  512. batch_size: The maybe padded batch size.
  513. """
  514. for inter_data in self.input_builder.inter_data_list:
  515. self._add_seq_group(inter_data,
  516. self.input_builder.chunked_prefill_enabled)
  517. device = self.runner.device
  518. use_captured_graph = cuda_graph_pad_size != -1
  519. max_query_len = max(query_lens)
  520. max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
  521. num_decode_tokens = self.num_decode_tokens
  522. if use_captured_graph:
  523. self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
  524. self.block_tables.extend([] * cuda_graph_pad_size)
  525. num_decode_tokens = batch_size
  526. # The shape of graph_block_tables is
  527. # [max batch size, max context len // block size].
  528. input_block_tables = self.runner.graph_block_tables[:batch_size]
  529. for i, block_table in enumerate(self.block_tables):
  530. if block_table:
  531. input_block_tables[i, :len(block_table)] = block_table
  532. block_tables = torch.from_numpy(input_block_tables).to(
  533. device, non_blocking=True)
  534. last_paged_kv_indptr = self.paged_kv_indptr[-1]
  535. self.paged_kv_indptr.extend([last_paged_kv_indptr] *
  536. cuda_graph_pad_size)
  537. self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
  538. else:
  539. block_tables = make_tensor_with_pad(
  540. self.block_tables,
  541. pad=0,
  542. dtype=torch.int,
  543. device=device,
  544. )
  545. assert max_query_len > 0, ("query_lens: {}".format(query_lens))
  546. assert device is not None
  547. seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
  548. self.runner.pin_memory)
  549. query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
  550. self.runner.pin_memory)
  551. slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
  552. device, self.runner.pin_memory)
  553. query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
  554. dtype=torch.int32,
  555. device=device)
  556. seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
  557. dtype=torch.int32,
  558. device=device)
  559. torch.cumsum(seq_lens_tensor,
  560. dim=0,
  561. dtype=seq_start_loc.dtype,
  562. out=seq_start_loc[1:])
  563. torch.cumsum(query_lens_tensor,
  564. dim=0,
  565. dtype=query_start_loc.dtype,
  566. out=query_start_loc[1:])
  567. if len(self.paged_kv_indptr) > 0:
  568. # extend to the maximum number of blocks as returned by the
  569. # scheduler
  570. self.paged_kv_indices.extend(
  571. [0] * (self.total_blocks - len(self.paged_kv_indices)))
  572. paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
  573. device="cpu",
  574. dtype=torch.int)
  575. paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
  576. device="cpu",
  577. dtype=torch.int)
  578. paged_kv_last_page_len_tensor = torch.tensor(
  579. self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
  580. block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
  581. 1,
  582. device="cpu",
  583. dtype=torch.int)
  584. else:
  585. paged_kv_indices_tensor = None
  586. paged_kv_indptr_tensor = None
  587. paged_kv_last_page_len_tensor = None
  588. block_table_bound_tensor = None
  589. if self.runner.kv_cache_dtype.startswith("fp8"):
  590. kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
  591. self.runner.kv_cache_dtype)
  592. else:
  593. kv_cache_dtype = get_kv_cache_torch_dtype(
  594. self.runner.kv_cache_dtype, self.runner.model_config.dtype)
  595. return FlashInferMetadata(
  596. num_prefills=self.num_prefills,
  597. slot_mapping=slot_mapping_tensor,
  598. num_prefill_tokens=self.num_prefill_tokens,
  599. num_decode_tokens=num_decode_tokens,
  600. max_prefill_seq_len=max_prefill_seq_len,
  601. block_tables=block_tables,
  602. paged_kv_indptr=paged_kv_indptr_tensor,
  603. paged_kv_indices=paged_kv_indices_tensor,
  604. paged_kv_last_page_len=paged_kv_last_page_len_tensor,
  605. block_table_bound=block_table_bound_tensor,
  606. seq_lens_tensor=seq_lens_tensor,
  607. num_qo_heads=self.runner.model_config.get_num_attention_heads(
  608. self.runner.parallel_config),
  609. num_kv_heads=self.runner.model_config.get_num_kv_heads(
  610. self.runner.parallel_config),
  611. head_dim=self.runner.model_config.get_head_size(),
  612. page_size=self.block_size,
  613. seq_start_loc=seq_start_loc,
  614. query_start_loc=query_start_loc,
  615. device=device,
  616. data_type=kv_cache_dtype,
  617. q_data_type=self.runner.model_config.dtype,
  618. use_cuda_graph=use_captured_graph,
  619. is_profile_run=self.is_profile_run)
  620. class FlashInferImpl(AttentionImpl):
  621. def __init__(
  622. self,
  623. num_heads: int,
  624. head_size: int,
  625. scale: float,
  626. num_kv_heads: int,
  627. alibi_slopes: Optional[List[float]],
  628. sliding_window: Optional[int],
  629. kv_cache_dtype: str,
  630. blocksparse_params: Optional[Dict[str, Any]] = None,
  631. logits_soft_cap: Optional[float] = None,
  632. ) -> None:
  633. self.num_heads = num_heads
  634. self.head_size = head_size
  635. self.scale = float(scale)
  636. self.num_kv_heads = num_kv_heads
  637. if alibi_slopes is not None:
  638. alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
  639. self.alibi_slopes = alibi_slopes
  640. if sliding_window is not None:
  641. raise ValueError("Sliding window is not supported in FlashInfer.")
  642. self.sliding_window = (-1, -1)
  643. self.kv_cache_dtype = kv_cache_dtype
  644. self.logits_soft_cap = logits_soft_cap
  645. assert self.num_heads % self.num_kv_heads == 0
  646. self.num_queries_per_kv = self.num_heads // self.num_kv_heads
  647. def forward(
  648. self,
  649. query: torch.Tensor,
  650. key: torch.Tensor,
  651. value: torch.Tensor,
  652. kv_cache: Optional[torch.Tensor],
  653. attn_metadata: FlashInferMetadata,
  654. k_scale: float = 1.0,
  655. v_scale: float = 1.0,
  656. attn_type: AttentionType = AttentionType.DECODER,
  657. ) -> torch.Tensor:
  658. assert k_scale == 1.0 and v_scale == 1.0, (
  659. "key/v_scale is not supported in FlashInfer.")
  660. if attn_type != AttentionType.DECODER:
  661. raise NotImplementedError("Encoder self-attention and "
  662. "encoder/decoder cross-attention "
  663. "are not implemented for "
  664. "FlashInferImpl")
  665. num_tokens, hidden_size = query.shape
  666. query = query.view(-1, self.num_heads, self.head_size)
  667. key = key.view(-1, self.num_kv_heads, self.head_size)
  668. value = value.view(-1, self.num_kv_heads, self.head_size)
  669. if attn_metadata.num_prefill_tokens > 0:
  670. assert attn_metadata.num_decode_tokens == 0, (
  671. "Chunked prefill is not supported with flashinfer yet.")
  672. if attn_metadata.num_decode_tokens > 0:
  673. assert attn_metadata.num_prefill_tokens == 0, (
  674. "Chunked prefill is not supported with flashinfer yet.")
  675. if kv_cache is not None:
  676. # Use the same reshape and cache kernel as flash attention.
  677. ops.reshape_and_cache_flash(
  678. key,
  679. value,
  680. kv_cache[:, 0],
  681. kv_cache[:, 1],
  682. attn_metadata.slot_mapping.flatten(),
  683. self.kv_cache_dtype,
  684. k_scale,
  685. v_scale,
  686. )
  687. # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
  688. # to process the cache when the kv_cache_dtype is fp8
  689. if self.kv_cache_dtype.startswith("fp8"):
  690. torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
  691. self.kv_cache_dtype)
  692. kv_cache = kv_cache.view(torch_dtype)
  693. query = query.contiguous(
  694. ) # Flashinfer requires query to be contiguous
  695. if prefill_meta := attn_metadata.prefill_metadata:
  696. # We will use flash attention for prefill
  697. # when kv_cache is not provided.
  698. # This happens when aphrodite runs the profiling to
  699. # determine the number of blocks.
  700. if kv_cache is None:
  701. output = torch.ops.aphrodite.flash_attn_varlen_func(
  702. q=query,
  703. k=key,
  704. v=value,
  705. cu_seqlens_q=prefill_meta.seq_start_loc,
  706. cu_seqlens_k=prefill_meta.seq_start_loc,
  707. max_seqlen_q=prefill_meta.max_prefill_seq_len,
  708. max_seqlen_k=prefill_meta.max_prefill_seq_len,
  709. softmax_scale=self.scale,
  710. causal=True,
  711. window_size=self.sliding_window,
  712. alibi_slopes=self.alibi_slopes,
  713. )
  714. else:
  715. assert prefill_meta is not None
  716. assert prefill_meta.prefill_wrapper is not None
  717. output = prefill_meta.prefill_wrapper.forward(
  718. query,
  719. kv_cache,
  720. logits_soft_cap=self.logits_soft_cap,
  721. causal=True)
  722. else:
  723. assert attn_metadata.decode_metadata is not None
  724. assert attn_metadata.decode_metadata.decode_wrapper is not None
  725. output = attn_metadata.decode_metadata.decode_wrapper.forward(
  726. query,
  727. kv_cache,
  728. sm_scale=self.scale,
  729. logits_soft_cap=self.logits_soft_cap,
  730. k_scale=k_scale,
  731. v_scale=v_scale)
  732. return output.view(num_tokens, hidden_size)