flashinfer.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833
  1. from contextlib import contextmanager
  2. from dataclasses import dataclass
  3. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
  4. try:
  5. from flashinfer import BatchDecodeWithPagedKVCacheWrapper
  6. from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
  7. from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
  8. import aphrodite.attention.backends.flash_attn # noqa
  9. FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
  10. except ImportError:
  11. BatchDecodeWithPagedKVCacheWrapper = None
  12. CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
  13. BatchPrefillWithPagedKVCacheWrapper = None
  14. FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
  15. import torch
  16. from aphrodite import _custom_ops as ops
  17. from aphrodite.attention.backends.abstract import (AttentionBackend,
  18. AttentionImpl,
  19. AttentionMetadata,
  20. AttentionMetadataBuilder,
  21. AttentionState,
  22. AttentionType)
  23. from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
  24. compute_slot_mapping,
  25. compute_slot_mapping_start_idx,
  26. is_block_tables_empty)
  27. from aphrodite.attention.ops.paged_attn import PagedAttention
  28. from aphrodite.common.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
  29. make_tensor_with_pad)
  30. if TYPE_CHECKING:
  31. from aphrodite.worker.model_runner import (
  32. ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata)
  33. class FlashInferBackend(AttentionBackend):
  34. @staticmethod
  35. def get_name() -> str:
  36. return "flashinfer"
  37. @staticmethod
  38. def get_impl_cls() -> Type["FlashInferImpl"]:
  39. return FlashInferImpl
  40. @staticmethod
  41. def get_metadata_cls() -> Type["AttentionMetadata"]:
  42. return FlashInferMetadata
  43. @staticmethod
  44. def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
  45. return FlashInferMetadataBuilder
  46. @staticmethod
  47. def get_state_cls() -> Type["FlashInferState"]:
  48. return FlashInferState
  49. @staticmethod
  50. def get_kv_cache_shape(
  51. num_blocks: int,
  52. block_size: int,
  53. num_kv_heads: int,
  54. head_size: int,
  55. ) -> Tuple[int, ...]:
  56. return (num_blocks, 2, block_size, num_kv_heads, head_size)
  57. @staticmethod
  58. def swap_blocks(
  59. src_kv_cache: torch.Tensor,
  60. dst_kv_cache: torch.Tensor,
  61. src_to_dst: torch.Tensor,
  62. ) -> None:
  63. PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
  64. @staticmethod
  65. def copy_blocks(
  66. kv_caches: List[torch.Tensor],
  67. src_to_dists: torch.Tensor,
  68. ) -> None:
  69. PagedAttention.copy_blocks(kv_caches, src_to_dists)
  70. @staticmethod
  71. def get_supported_head_sizes() -> List[int]:
  72. return [64, 128, 256]
  73. @staticmethod
  74. def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
  75. if kv_cache_dtype in ("fp8", "fp8_e4m3"):
  76. return torch.float8_e4m3fn
  77. elif kv_cache_dtype == "fp8_e5m2":
  78. return torch.float8_e5m2
  79. else:
  80. raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
  81. class FlashInferState(AttentionState):
  82. def __init__(self, runner):
  83. self.runner = runner
  84. self._is_graph_capturing = False
  85. self._workspace_buffer = None
  86. self._decode_wrapper = None
  87. self._prefill_wrapper = None
  88. def _get_workspace_buffer(self):
  89. if self._workspace_buffer is None:
  90. self._workspace_buffer = torch.empty(
  91. FLASHINFER_WORKSPACE_BUFFER_SIZE,
  92. dtype=torch.uint8,
  93. device=self.runner.device)
  94. return self._workspace_buffer
  95. def _get_prefill_wrapper(self):
  96. if self._prefill_wrapper is None:
  97. self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
  98. self._get_workspace_buffer(), "NHD")
  99. return self._prefill_wrapper
  100. def _get_decode_wrapper(self):
  101. if self._decode_wrapper is None:
  102. num_qo_heads = (self.runner.model_config.get_num_attention_heads(
  103. self.runner.parallel_config))
  104. num_kv_heads = self.runner.model_config.get_num_kv_heads(
  105. self.runner.parallel_config)
  106. use_tensor_cores = num_qo_heads // num_kv_heads > 4
  107. self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
  108. self._get_workspace_buffer(),
  109. "NHD",
  110. use_tensor_cores=use_tensor_cores)
  111. return self._decode_wrapper
  112. @contextmanager
  113. def graph_capture(self, max_batch_size: int):
  114. self._is_graph_capturing = True
  115. self._graph_decode_wrapper = None
  116. self._graph_slot_mapping = torch.full((max_batch_size, ),
  117. PAD_SLOT_ID,
  118. dtype=torch.long,
  119. device=self.runner.device)
  120. self._graph_seq_lens = torch.ones(max_batch_size,
  121. dtype=torch.int32,
  122. device=self.runner.device)
  123. self._graph_block_tables = torch.from_numpy(
  124. self.runner.graph_block_tables).to(device=self.runner.device)
  125. self._graph_decode_workspace_buffer = self._get_workspace_buffer()
  126. self._graph_indices_buffer = torch.empty(
  127. max_batch_size * self.runner.cache_config.num_gpu_blocks,
  128. dtype=torch.int32,
  129. device=self.runner.device)
  130. self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
  131. dtype=torch.int32,
  132. device=self.runner.device)
  133. self._graph_last_page_len_buffer = torch.empty(
  134. max_batch_size, dtype=torch.int32, device=self.runner.device)
  135. yield
  136. self._is_graph_capturing = False
  137. del self._graph_slot_mapping
  138. del self._graph_seq_lens
  139. del self._graph_block_tables
  140. del self._graph_decode_workspace_buffer
  141. del self._graph_indices_buffer
  142. del self._graph_indptr_buffer
  143. del self._graph_last_page_len_buffer
  144. del self._graph_decode_wrapper
  145. def graph_clone(self, batch_size: int):
  146. assert self._is_graph_capturing
  147. state = self.__class__(self.runner)
  148. state._workspace_buffer = self._graph_decode_workspace_buffer
  149. state._decode_wrapper = self._graph_decode_wrapper
  150. state._prefill_wrapper = self._get_prefill_wrapper()
  151. return state
  152. def graph_capture_get_metadata_for_batch(
  153. self, batch_size: int, is_encoder_decoder_model: bool = False):
  154. assert self._is_graph_capturing
  155. _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
  156. _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]
  157. num_qo_heads = (self.runner.model_config.get_num_attention_heads(
  158. self.runner.parallel_config))
  159. num_kv_heads = self.runner.model_config.get_num_kv_heads(
  160. self.runner.parallel_config)
  161. use_tensor_cores = num_qo_heads // num_kv_heads > 4
  162. self._graph_decode_wrapper = \
  163. CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
  164. self._graph_decode_workspace_buffer, _indptr_buffer,
  165. self._graph_indices_buffer, _last_page_len_buffer, "NHD",
  166. use_tensor_cores)
  167. if self.runner.kv_cache_dtype.startswith("fp8"):
  168. kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
  169. self.runner.kv_cache_dtype)
  170. else:
  171. kv_cache_dtype = get_kv_cache_torch_dtype(
  172. self.runner.kv_cache_dtype, self.runner.model_config.dtype)
  173. paged_kv_indptr_tensor_host = torch.arange(0,
  174. batch_size + 1,
  175. dtype=torch.int32)
  176. paged_kv_indices_tensor_host = torch.arange(0,
  177. batch_size,
  178. dtype=torch.int32)
  179. paged_kv_last_page_len_tensor_host = torch.full((batch_size, ),
  180. self.runner.block_size,
  181. dtype=torch.int32)
  182. query_start_loc_host = torch.arange(0,
  183. batch_size + 1,
  184. dtype=torch.int32)
  185. attn_metadata = self.runner.attn_backend.make_metadata(
  186. num_prefills=0,
  187. slot_mapping=self._graph_slot_mapping[:batch_size],
  188. num_prefill_tokens=0,
  189. num_decode_tokens=batch_size,
  190. max_prefill_seq_len=0,
  191. block_tables=self._graph_block_tables,
  192. paged_kv_indptr=paged_kv_indptr_tensor_host,
  193. paged_kv_indices=paged_kv_indices_tensor_host,
  194. paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
  195. num_qo_heads=num_qo_heads,
  196. num_kv_heads=num_kv_heads,
  197. head_dim=self.runner.model_config.get_head_size(),
  198. page_size=self.runner.block_size,
  199. seq_start_loc=None,
  200. query_start_loc=query_start_loc_host,
  201. device=self.runner.device,
  202. data_type=kv_cache_dtype,
  203. q_data_type=self.runner.model_config.dtype,
  204. use_cuda_graph=True,
  205. decode_wrapper=self._graph_decode_wrapper,
  206. prefill_wrapper=None)
  207. attn_metadata.begin_forward()
  208. return attn_metadata
  209. def get_graph_input_buffers(self,
  210. attn_metadata,
  211. is_encoder_decoder_model: bool = False):
  212. return {
  213. "slot_mapping": attn_metadata.slot_mapping,
  214. }
  215. def prepare_graph_input_buffers(self,
  216. input_buffers,
  217. attn_metadata,
  218. is_encoder_decoder_model: bool = False):
  219. return
  220. def begin_forward(self, model_input):
  221. assert not self._is_graph_capturing
  222. state = self
  223. if model_input.attn_metadata.use_cuda_graph:
  224. batch_size = model_input.input_tokens.shape[0]
  225. state = (self.runner.graph_runners[model_input.virtual_engine]
  226. [batch_size].attn_state)
  227. model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
  228. )
  229. model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
  230. model_input.attn_metadata.begin_forward()
  231. @dataclass
  232. class FlashInferMetadata(AttentionMetadata):
  233. # Maximum sequence length among prefill batch. 0 if there are decoding
  234. # requests only.
  235. max_prefill_seq_len: int
  236. use_cuda_graph: bool = True
  237. prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
  238. decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
  239. # Metadata for the prefill stage
  240. seq_start_loc: Optional[torch.Tensor] = None
  241. query_start_loc: Optional[torch.Tensor] = None
  242. block_tables: Optional[torch.Tensor] = None
  243. # used for GPU in-place advance_step
  244. seq_lens_tensor: Optional[torch.Tensor] = None
  245. block_table_bound: Optional[torch.Tensor] = None
  246. # An example for paged_kv_indices, paged_kv_indptr:
  247. # request 1, page indices [0, 5, 8]
  248. # request 2, page indices [1, 6, 7]
  249. # request 3, page indices [3, 4]
  250. # paged_kv_indices is a concatenation of page indices of all requests:
  251. # [0, 5, 8, 1, 6, 7, 3, 4]
  252. # paged_kv_indptr is used to index into paged_kv_indices:
  253. # [0, 3, 6, 8]
  254. # The indptr of the paged kv cache, shape: [batch_size + 1]
  255. paged_kv_indptr: Optional[torch.Tensor] = None
  256. # The page indices of the paged kv cache
  257. paged_kv_indices: Optional[torch.Tensor] = None
  258. # The number of entries in the last page of each request in
  259. # the paged kv cache, shape: [batch_size]
  260. paged_kv_last_page_len: Optional[torch.Tensor] = None
  261. # The number of query/output heads
  262. num_qo_heads: Optional[int] = None
  263. # The number of key/value heads
  264. num_kv_heads: Optional[int] = None
  265. # The dimension of the attention heads
  266. head_dim: Optional[int] = None
  267. # Block size of aphrodite
  268. page_size: Optional[int] = None
  269. # The data type of the paged kv cache
  270. data_type: torch.dtype = None
  271. # The data type of the query
  272. q_data_type: torch.dtype = None
  273. device: torch.device = torch.device("cuda")
  274. is_profile_run: bool = False
  275. def __post_init__(self):
  276. # Refer to
  277. # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
  278. supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
  279. if self.head_dim is not None and self.head_dim \
  280. not in supported_head_sizes:
  281. raise ValueError(
  282. f"Only {supported_head_sizes} are supported for head_dim,",
  283. f"received {self.head_dim}.")
  284. def begin_forward(self):
  285. if self.num_prefill_tokens > 0:
  286. if self.paged_kv_indices is None:
  287. return
  288. assert self.prefill_wrapper is not None
  289. assert self.query_start_loc is not None
  290. assert self.paged_kv_indices is not None
  291. assert self.paged_kv_indptr is not None
  292. assert self.paged_kv_last_page_len is not None
  293. assert self.block_table_bound is not None
  294. assert self.seq_lens_tensor is not None
  295. batch_size = self.query_start_loc.shape[0] - 1
  296. assert batch_size >= 0
  297. # We will use flash attention for profiling to
  298. # determine the number of blocks. Therefore,
  299. # we don't need to prepare the input for flashinfer for profile run.
  300. if not self.is_profile_run:
  301. self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
  302. self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
  303. self.device)
  304. self.block_table_bound = self.block_table_bound.to(self.device)
  305. self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
  306. self.paged_kv_indices = self.paged_kv_indices.to(self.device)
  307. self.prefill_wrapper.end_forward()
  308. self.prefill_wrapper.begin_forward(
  309. self.query_start_loc, self.paged_kv_indptr,
  310. self.paged_kv_indices, self.paged_kv_last_page_len,
  311. self.num_qo_heads, self.num_kv_heads, self.head_dim,
  312. self.page_size)
  313. else:
  314. assert self.paged_kv_indices is not None
  315. assert self.paged_kv_indptr is not None
  316. assert self.paged_kv_last_page_len is not None
  317. self.paged_kv_indices = self.paged_kv_indices.to(self.device)
  318. self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
  319. self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
  320. self.device)
  321. # handle model warmup path
  322. if self.block_table_bound is not None:
  323. self.block_table_bound = self.block_table_bound.to(self.device)
  324. if self.seq_lens_tensor is not None:
  325. self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
  326. assert self.decode_wrapper is not None
  327. self.decode_wrapper.end_forward()
  328. self.decode_wrapper.begin_forward(
  329. self.paged_kv_indptr,
  330. self.paged_kv_indices,
  331. self.paged_kv_last_page_len,
  332. self.num_qo_heads,
  333. self.num_kv_heads,
  334. self.head_dim,
  335. self.page_size,
  336. # Disable flashinfer's pos encoding and use aphrodite's rope.
  337. pos_encoding_mode="NONE",
  338. # kv-cache data type.
  339. data_type=self.data_type,
  340. # query data type.
  341. q_data_type=self.q_data_type)
  342. def asdict_zerocopy(self,
  343. skip_fields: Optional[Set[str]] = None
  344. ) -> Dict[str, Any]:
  345. if skip_fields is None:
  346. skip_fields = set()
  347. # We need to skip the prefill/decode_wrapper field since it cannot be
  348. # broadcasted with nccl when TP is enabled.
  349. skip_fields.add('prefill_wrapper')
  350. skip_fields.add('decode_wrapper')
  351. return super().asdict_zerocopy(skip_fields)
  352. @property
  353. def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
  354. # Currently chunked prefill is not supported
  355. if self.num_decode_tokens == 0:
  356. assert self.num_prefills > 0
  357. return self
  358. return None
  359. @property
  360. def decode_metadata(self) -> Optional["FlashInferMetadata"]:
  361. # Currently chunked prefill is not supported
  362. if self.num_prefills > 0:
  363. assert self.num_decode_tokens == 0, (
  364. "Chunked prefill is not supported with flashinfer yet.")
  365. return None
  366. return self
  367. def advance_step(
  368. self,
  369. model_input: "ModelInputForGPUWithSamplingMetadata",
  370. sampled_token_ids: Optional[torch.Tensor],
  371. block_size: int,
  372. num_seqs: int,
  373. num_queries: int,
  374. ):
  375. """
  376. Update metadata in-place to advance one decode step.
  377. """
  378. assert num_seqs > 0
  379. assert num_queries > 0
  380. assert model_input.attn_metadata is not None
  381. assert sampled_token_ids is not None
  382. # When using cudagraph, the num_seqs is padded to the next captured
  383. # batch sized, but num_queries tracks the actual number of requests in
  384. # the batch. For --enforce-eager mode, num_seqs == num_queries
  385. if num_seqs != num_queries:
  386. assert num_seqs > num_queries
  387. assert self.use_cuda_graph
  388. model_input.input_tokens[:num_queries] = sampled_token_ids.flatten()
  389. # Update GPU tensors
  390. ops.advance_step_flashinfer(
  391. num_seqs=num_seqs,
  392. num_queries=num_queries,
  393. block_size=block_size,
  394. input_tokens=model_input.input_tokens,
  395. sampled_token_ids=model_input.input_tokens,
  396. input_positions=model_input.input_positions,
  397. seq_lens=self.seq_lens_tensor,
  398. slot_mapping=self.slot_mapping,
  399. block_tables=self.block_tables,
  400. paged_kv_indices=self.paged_kv_indices,
  401. paged_kv_indptr=self.paged_kv_indptr,
  402. paged_kv_last_page_len=self.paged_kv_last_page_len,
  403. block_table_bound=self.block_table_bound)
  404. class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
  405. def __init__(self, input_builder: "ModelInputForGPUBuilder"):
  406. self.slot_mapping: List[int] = []
  407. self.prefill_seq_lens: List[int] = []
  408. self.context_lens: List[int] = []
  409. self.block_tables: List[List[int]] = []
  410. self.curr_seq_lens: List[int] = []
  411. self.num_prefills = 0
  412. self.num_prefill_tokens = 0
  413. self.num_decode_tokens = 0
  414. self.input_builder = input_builder
  415. self.runner = input_builder.runner
  416. self.sliding_window = input_builder.sliding_window
  417. self.block_size = input_builder.block_size
  418. self.use_v2_block_manager = (
  419. input_builder.scheduler_config.use_v2_block_manager)
  420. # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
  421. # for the precise definition of the following fields.
  422. # An example:
  423. # request 1, page indices [0, 5, 8]
  424. # request 2, page indices [1, 6, 7]
  425. # request 3, page indices [3, 4]
  426. # paged_kv_indices is a concatenation of page indices of all requests:
  427. # [0, 5, 8, 1, 6, 7, 3, 4]
  428. # paged_kv_indptr is used to index into paged_kv_indices:
  429. # [0, 3, 6, 8]
  430. self.paged_kv_indices: List[int] = []
  431. # 0 at the beginning of paged_kv_indptr indicates the start of the
  432. # first request’s page indices in the paged_kv_indices list.
  433. self.paged_kv_indptr: List[int] = [0]
  434. # paged_kv_last_page_len is the length of the last page of each request
  435. self.paged_kv_last_page_len: List[int] = []
  436. self.total_blocks = 0
  437. self.is_profile_run: bool = False
  438. def _add_seq_group(
  439. self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
  440. chunked_prefill_enabled: bool):
  441. """Add a sequence group to the metadata. Specifically update/append
  442. 1. context length.
  443. 2. block table.
  444. 3. slot mapping.
  445. """
  446. is_prompt = inter_data.is_prompt
  447. block_tables = inter_data.block_tables
  448. computed_block_nums = inter_data.computed_block_nums
  449. for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
  450. curr_sliding_window_block) in zip(
  451. inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
  452. inter_data.orig_seq_lens, inter_data.seq_lens,
  453. inter_data.query_lens, inter_data.context_lens,
  454. inter_data.curr_sliding_window_blocks):
  455. self.context_lens.append(context_len)
  456. if is_prompt:
  457. self.num_prefills += 1
  458. self.num_prefill_tokens += token_len
  459. self.prefill_seq_lens.append(seq_len)
  460. else:
  461. assert query_len == 1, (
  462. "seq_len: {}, context_len: {}, query_len: {}".format(
  463. seq_len, context_len, query_len))
  464. self.num_decode_tokens += query_len
  465. self.curr_seq_lens.append(curr_seq_len)
  466. # Compute block table.
  467. # TODO(sang): Combine chunked prefill and prefix caching by
  468. # only allowing multiple of block_size chunk size.
  469. # NOTE: This only works for oooooooxxx style attention.
  470. block_table = []
  471. if inter_data.prefix_cache_hit:
  472. block_table = computed_block_nums
  473. elif ((chunked_prefill_enabled or not is_prompt)
  474. and block_tables is not None):
  475. block_table = block_tables[seq_id][-curr_sliding_window_block:]
  476. self.block_tables.append(block_table)
  477. is_profile_run = is_block_tables_empty(block_tables)
  478. # Compute slot mapping.
  479. start_idx = compute_slot_mapping_start_idx(
  480. is_prompt, query_len, context_len, self.sliding_window,
  481. self.use_v2_block_manager)
  482. compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
  483. seq_len, context_len, start_idx,
  484. self.block_size, inter_data.block_tables)
  485. # It is not necessary to add paged_kv_indices, paged_kv_indptr,
  486. # and paged_kv_last_page_len for profile run because we will
  487. # create dummy inputs.
  488. if is_profile_run:
  489. self.is_profile_run = is_profile_run
  490. return
  491. block_table = block_tables[seq_id]
  492. self._update_paged_kv_tensors(block_table, seq_len)
  493. def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
  494. # Get the number of valid blocks based on sequence length.
  495. # If seq_len = 16, block_size = 16,
  496. # block_table_bound is 1 with 1 valid block.
  497. # If seq_len = 15, block_size = 16,
  498. # block_table_bound is 0 + 1 with 1 valid block.
  499. self.total_blocks += len(block_table)
  500. block_table_bound = seq_len // self.block_size + 1 \
  501. if seq_len % self.block_size != 0 \
  502. else seq_len // self.block_size
  503. self.paged_kv_indices.extend(block_table[:block_table_bound])
  504. self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
  505. block_table_bound)
  506. last_page_len = seq_len % self.block_size
  507. if last_page_len == 0:
  508. last_page_len = self.block_size
  509. self.paged_kv_last_page_len.append(last_page_len)
  510. def build(self, seq_lens: List[int], query_lens: List[int],
  511. cuda_graph_pad_size: int, batch_size: int):
  512. """Build attention metadata with on-device tensors.
  513. Args:
  514. seq_lens: The maybe padded sequence lengths of the input sequences.
  515. query_lens: The query lengths of the input sequences.
  516. cuda_graph_pad_size: The padding size for cuda graph.
  517. -1 if cuda graph is not used.
  518. batch_size: The maybe padded batch size.
  519. """
  520. for inter_data in self.input_builder.inter_data_list:
  521. self._add_seq_group(inter_data,
  522. self.input_builder.chunked_prefill_enabled)
  523. device = self.runner.device
  524. use_captured_graph = cuda_graph_pad_size != -1
  525. max_query_len = max(query_lens)
  526. max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
  527. num_decode_tokens = self.num_decode_tokens
  528. if use_captured_graph:
  529. self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
  530. self.block_tables.extend([] * cuda_graph_pad_size)
  531. num_decode_tokens = batch_size
  532. # The shape of graph_block_tables is
  533. # [max batch size, max context len // block size].
  534. input_block_tables = self.runner.graph_block_tables[:batch_size]
  535. max_blocks = input_block_tables.shape[1]
  536. for i, block_table in enumerate(self.block_tables):
  537. if block_table:
  538. num_blocks = len(block_table)
  539. if num_blocks <= max_blocks:
  540. input_block_tables[i, :num_blocks] = block_table
  541. else:
  542. # It may be possible to have more blocks allocated due
  543. # to lookahead slots of multi-step, however, they are
  544. # not used anyway, so can be safely ignored.
  545. input_block_tables[
  546. i, :max_blocks] = block_table[:max_blocks]
  547. block_tables = torch.from_numpy(input_block_tables).to(
  548. device, non_blocking=True)
  549. last_paged_kv_indptr = self.paged_kv_indptr[-1]
  550. self.paged_kv_indptr.extend([last_paged_kv_indptr] *
  551. cuda_graph_pad_size)
  552. self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
  553. else:
  554. block_tables = make_tensor_with_pad(
  555. self.block_tables,
  556. pad=0,
  557. dtype=torch.int,
  558. device=device,
  559. )
  560. assert max_query_len > 0, ("query_lens: {}".format(query_lens))
  561. assert device is not None
  562. seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
  563. self.runner.pin_memory)
  564. query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
  565. self.runner.pin_memory)
  566. slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
  567. device, self.runner.pin_memory)
  568. query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
  569. dtype=torch.int32,
  570. device=device)
  571. seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
  572. dtype=torch.int32,
  573. device=device)
  574. torch.cumsum(seq_lens_tensor,
  575. dim=0,
  576. dtype=seq_start_loc.dtype,
  577. out=seq_start_loc[1:])
  578. torch.cumsum(query_lens_tensor,
  579. dim=0,
  580. dtype=query_start_loc.dtype,
  581. out=query_start_loc[1:])
  582. if len(self.paged_kv_indptr) > 0:
  583. # extend to the maximum number of blocks as returned by the
  584. # scheduler
  585. self.paged_kv_indices.extend(
  586. [0] * (self.total_blocks - len(self.paged_kv_indices)))
  587. paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
  588. device="cpu",
  589. dtype=torch.int)
  590. paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
  591. device="cpu",
  592. dtype=torch.int)
  593. paged_kv_last_page_len_tensor = torch.tensor(
  594. self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
  595. block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
  596. 1,
  597. device="cpu",
  598. dtype=torch.int)
  599. else:
  600. paged_kv_indices_tensor = None
  601. paged_kv_indptr_tensor = None
  602. paged_kv_last_page_len_tensor = None
  603. block_table_bound_tensor = None
  604. if self.runner.kv_cache_dtype.startswith("fp8"):
  605. kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
  606. self.runner.kv_cache_dtype)
  607. else:
  608. kv_cache_dtype = get_kv_cache_torch_dtype(
  609. self.runner.kv_cache_dtype, self.runner.model_config.dtype)
  610. return FlashInferMetadata(
  611. num_prefills=self.num_prefills,
  612. slot_mapping=slot_mapping_tensor,
  613. num_prefill_tokens=self.num_prefill_tokens,
  614. num_decode_tokens=num_decode_tokens,
  615. max_prefill_seq_len=max_prefill_seq_len,
  616. block_tables=block_tables,
  617. paged_kv_indptr=paged_kv_indptr_tensor,
  618. paged_kv_indices=paged_kv_indices_tensor,
  619. paged_kv_last_page_len=paged_kv_last_page_len_tensor,
  620. block_table_bound=block_table_bound_tensor,
  621. seq_lens_tensor=seq_lens_tensor,
  622. num_qo_heads=self.runner.model_config.get_num_attention_heads(
  623. self.runner.parallel_config),
  624. num_kv_heads=self.runner.model_config.get_num_kv_heads(
  625. self.runner.parallel_config),
  626. head_dim=self.runner.model_config.get_head_size(),
  627. page_size=self.block_size,
  628. seq_start_loc=seq_start_loc,
  629. query_start_loc=query_start_loc,
  630. device=device,
  631. data_type=kv_cache_dtype,
  632. q_data_type=self.runner.model_config.dtype,
  633. use_cuda_graph=use_captured_graph,
  634. is_profile_run=self.is_profile_run)
  635. class FlashInferImpl(AttentionImpl):
  636. def __init__(
  637. self,
  638. num_heads: int,
  639. head_size: int,
  640. scale: float,
  641. num_kv_heads: int,
  642. alibi_slopes: Optional[List[float]],
  643. sliding_window: Optional[int],
  644. kv_cache_dtype: str,
  645. blocksparse_params: Optional[Dict[str, Any]] = None,
  646. logits_soft_cap: Optional[float] = None,
  647. ) -> None:
  648. self.num_heads = num_heads
  649. self.head_size = head_size
  650. self.scale = float(scale)
  651. self.num_kv_heads = num_kv_heads
  652. if alibi_slopes is not None:
  653. alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
  654. self.alibi_slopes = alibi_slopes
  655. if sliding_window is not None:
  656. raise ValueError("Sliding window is not supported in FlashInfer.")
  657. self.sliding_window = (-1, -1)
  658. self.kv_cache_dtype = kv_cache_dtype
  659. self.logits_soft_cap = logits_soft_cap
  660. assert self.num_heads % self.num_kv_heads == 0
  661. self.num_queries_per_kv = self.num_heads // self.num_kv_heads
  662. def forward(
  663. self,
  664. query: torch.Tensor,
  665. key: torch.Tensor,
  666. value: torch.Tensor,
  667. kv_cache: Optional[torch.Tensor],
  668. attn_metadata: FlashInferMetadata,
  669. k_scale: float = 1.0,
  670. v_scale: float = 1.0,
  671. attn_type: AttentionType = AttentionType.DECODER,
  672. ) -> torch.Tensor:
  673. assert k_scale == 1.0 and v_scale == 1.0, (
  674. "key/v_scale is not supported in FlashInfer.")
  675. if attn_type != AttentionType.DECODER:
  676. raise NotImplementedError("Encoder self-attention and "
  677. "encoder/decoder cross-attention "
  678. "are not implemented for "
  679. "FlashInferImpl")
  680. num_tokens, hidden_size = query.shape
  681. query = query.view(-1, self.num_heads, self.head_size)
  682. key = key.view(-1, self.num_kv_heads, self.head_size)
  683. value = value.view(-1, self.num_kv_heads, self.head_size)
  684. if attn_metadata.num_prefill_tokens > 0:
  685. assert attn_metadata.num_decode_tokens == 0, (
  686. "Chunked prefill is not supported with flashinfer yet.")
  687. if attn_metadata.num_decode_tokens > 0:
  688. assert attn_metadata.num_prefill_tokens == 0, (
  689. "Chunked prefill is not supported with flashinfer yet.")
  690. if kv_cache is not None:
  691. # Use the same reshape and cache kernel as flash attention.
  692. ops.reshape_and_cache_flash(
  693. key,
  694. value,
  695. kv_cache[:, 0],
  696. kv_cache[:, 1],
  697. attn_metadata.slot_mapping.flatten(),
  698. self.kv_cache_dtype,
  699. k_scale,
  700. v_scale,
  701. )
  702. # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
  703. # to process the cache when the kv_cache_dtype is fp8
  704. if self.kv_cache_dtype.startswith("fp8"):
  705. torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
  706. self.kv_cache_dtype)
  707. kv_cache = kv_cache.view(torch_dtype)
  708. query = query.contiguous(
  709. ) # Flashinfer requires query to be contiguous
  710. if prefill_meta := attn_metadata.prefill_metadata:
  711. # We will use flash attention for prefill
  712. # when kv_cache is not provided.
  713. # This happens when aphrodite runs the profiling to
  714. # determine the number of blocks.
  715. if kv_cache is None:
  716. output = torch.ops.aphrodite.flash_attn_varlen_func(
  717. q=query,
  718. k=key,
  719. v=value,
  720. cu_seqlens_q=prefill_meta.seq_start_loc,
  721. cu_seqlens_k=prefill_meta.seq_start_loc,
  722. max_seqlen_q=prefill_meta.max_prefill_seq_len,
  723. max_seqlen_k=prefill_meta.max_prefill_seq_len,
  724. softmax_scale=self.scale,
  725. causal=True,
  726. window_size=self.sliding_window,
  727. alibi_slopes=self.alibi_slopes,
  728. )
  729. else:
  730. assert prefill_meta is not None
  731. assert prefill_meta.prefill_wrapper is not None
  732. output = prefill_meta.prefill_wrapper.forward(
  733. query,
  734. kv_cache,
  735. logits_soft_cap=self.logits_soft_cap,
  736. causal=True)
  737. else:
  738. assert attn_metadata.decode_metadata is not None
  739. assert attn_metadata.decode_metadata.decode_wrapper is not None
  740. output = attn_metadata.decode_metadata.decode_wrapper.forward(
  741. query,
  742. kv_cache,
  743. sm_scale=self.scale,
  744. logits_soft_cap=self.logits_soft_cap,
  745. k_scale=k_scale,
  746. v_scale=v_scale)
  747. return output.view(num_tokens, hidden_size)