flashinfer.py

from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type

try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper

    import aphrodite.attention.backends.flash_attn  # noqa
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
    BatchDecodeWithPagedKVCacheWrapper = None
    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
    BatchPrefillWithPagedKVCacheWrapper = None
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

import torch

from aphrodite import _custom_ops as ops
from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                   AttentionImpl,
                                                   AttentionMetadata,
                                                   AttentionMetadataBuilder,
                                                   AttentionState,
                                                   AttentionType)
from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
                                                compute_slot_mapping,
                                                compute_slot_mapping_start_idx,
                                                is_block_tables_empty)
from aphrodite.attention.ops.paged_attn import PagedAttention
from aphrodite.common.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                                    make_tensor_with_pad)

if TYPE_CHECKING:
    from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder


class FlashInferBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "flashinfer"

    @staticmethod
    def get_impl_cls() -> Type["FlashInferImpl"]:
        return FlashInferImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return FlashInferMetadata

    @staticmethod
    def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
        return FlashInferMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["FlashInferState"]:
        return FlashInferState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
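        # The size-2 dimension separates the key and value planes of each
        # block: FlashInferImpl.forward() below writes new keys into
        # kv_cache[:, 0] and values into kv_cache[:, 1] via
        # reshape_and_cache_flash().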
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        PagedAttention.copy_blocks(kv_caches, src_to_dists)

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [64, 128, 256]


class FlashInferState(AttentionState):

    def __init__(self, runner):
        self.runner = runner
        self._is_graph_capturing = False
        self._workspace_buffer = None
        self._decode_wrapper = None
        self._prefill_wrapper = None

    def _get_workspace_buffer(self):
        if self._workspace_buffer is None:
            self._workspace_buffer = torch.empty(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.runner.device)
        return self._workspace_buffer

    def _get_prefill_wrapper(self):
        if self._prefill_wrapper is None:
            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self._get_workspace_buffer(), "NHD")
        return self._prefill_wrapper

    def _get_decode_wrapper(self):
        if self._decode_wrapper is None:
            num_qo_heads = (self.runner.model_config.get_num_attention_heads(
                self.runner.parallel_config))
            num_kv_heads = self.runner.model_config.get_num_kv_heads(
                self.runner.parallel_config)
            # Prefer FlashInfer's tensor-core decode kernels when the GQA
            # group size (query heads per KV head) is at least 4.
            use_tensor_cores = num_qo_heads // num_kv_heads >= 4
            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._get_workspace_buffer(),
                "NHD",
                use_tensor_cores=use_tensor_cores)
        return self._decode_wrapper

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        self._is_graph_capturing = True
        self._graph_decode_wrapper = None
        self._graph_slot_mapping = torch.full((max_batch_size, ),
                                              PAD_SLOT_ID,
                                              dtype=torch.long,
                                              device=self.runner.device)
        self._graph_seq_lens = torch.ones(max_batch_size,
                                          dtype=torch.int32,
                                          device=self.runner.device)
        self._graph_block_tables = torch.from_numpy(
            self.runner.graph_block_tables).to(device=self.runner.device)
        self._graph_decode_workspace_buffer = self._get_workspace_buffer()
        self._graph_indices_buffer = torch.empty(
            max_batch_size * self.runner.cache_config.num_gpu_blocks,
            dtype=torch.int32,
            device=self.runner.device)
        self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
                                                dtype=torch.int32,
                                                device=self.runner.device)
        self._graph_last_page_len_buffer = torch.empty(
            max_batch_size, dtype=torch.int32, device=self.runner.device)
        yield
        self._is_graph_capturing = False
        del self._graph_slot_mapping
        del self._graph_seq_lens
        del self._graph_block_tables
        del self._graph_decode_workspace_buffer
        del self._graph_indices_buffer
        del self._graph_indptr_buffer
        del self._graph_last_page_len_buffer
        del self._graph_decode_wrapper

    def graph_clone(self, batch_size: int):
        assert self._is_graph_capturing
        state = self.__class__(self.runner)
        state._workspace_buffer = self._graph_decode_workspace_buffer
        state._decode_wrapper = self._graph_decode_wrapper
        state._prefill_wrapper = self._get_prefill_wrapper()
        return state

    def graph_capture_get_metadata_for_batch(self, batch_size: int):
        assert self._is_graph_capturing
        _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
        _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]

        num_qo_heads = (self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config))
        num_kv_heads = self.runner.model_config.get_num_kv_heads(
            self.runner.parallel_config)
        use_tensor_cores = num_qo_heads // num_kv_heads >= 4
        self._graph_decode_wrapper = \
            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
                self._graph_decode_workspace_buffer, _indptr_buffer,
                self._graph_indices_buffer, _last_page_len_buffer, "NHD",
                use_tensor_cores)
        kv_cache_dtype = get_kv_cache_torch_dtype(
            self.runner.kv_cache_dtype, self.runner.model_config.dtype)
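
        # Placeholder host-side values used only to capture the decode graph:
        # each request is given exactly one page (indptr = [0, 1, ..., bs],
        # one page index apiece) whose last page is full (block_size entries).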
        paged_kv_indptr_tensor_host = torch.arange(0,
                                                   batch_size + 1,
                                                   dtype=torch.int32)
        paged_kv_indices_tensor_host = torch.arange(0,
                                                    batch_size,
                                                    dtype=torch.int32)
        paged_kv_last_page_len_tensor_host = torch.full(
            (batch_size, ), self.runner.block_size, dtype=torch.int32)
        query_start_loc_host = torch.arange(0,
                                            batch_size + 1,
                                            dtype=torch.int32)

        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            max_prefill_seq_len=0,
            block_tables=self._graph_block_tables,
            paged_kv_indptr=paged_kv_indptr_tensor_host,
            paged_kv_indices=paged_kv_indices_tensor_host,
            paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim=self.runner.model_config.get_head_size(),
            page_size=self.runner.block_size,
            seq_start_loc=None,
            query_start_loc=query_start_loc_host,
            device=self.runner.device,
            data_type=kv_cache_dtype,
            use_cuda_graph=True,
            decode_wrapper=self._graph_decode_wrapper,
            prefill_wrapper=None)
        attn_metadata.begin_forward()
        return attn_metadata

    def get_graph_input_buffers(self, attn_metadata):
        return {
            "slot_mapping": attn_metadata.slot_mapping,
        }

    def prepare_graph_input_buffers(self, input_buffers, attn_metadata):
        return

    def begin_forward(self, model_input):
        assert not self._is_graph_capturing
        state = self
        if model_input.attn_metadata.use_cuda_graph:
            batch_size = model_input.input_tokens.shape[0]
            state = (self.runner.graph_runners[model_input.virtual_engine]
                     [batch_size].attn_state)
        model_input.attn_metadata.prefill_wrapper = \
            state._get_prefill_wrapper()
        model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
        model_input.attn_metadata.begin_forward()


@dataclass
class FlashInferMetadata(AttentionMetadata):
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int

    use_cuda_graph: bool = True

    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None

    # Metadata for the prefill stage
    seq_start_loc: Optional[torch.Tensor] = None
    query_start_loc: Optional[torch.Tensor] = None
    block_tables: Optional[torch.Tensor] = None

    # An example for paged_kv_indices, paged_kv_indptr:
    # request 1, page indices [0, 5, 8]
    # request 2, page indices [1, 6, 7]
    # request 3, page indices [3, 4]
    # paged_kv_indices is a concatenation of page indices of all requests:
    # [0, 5, 8, 1, 6, 7, 3, 4]
    # paged_kv_indptr is used to index into paged_kv_indices:
    # [0, 3, 6, 8]
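    # The pages of the i-th request (0-indexed) are therefore
    #   paged_kv_indices[paged_kv_indptr[i]:paged_kv_indptr[i + 1]]
    # e.g. i = 1 (the second request) gives [1, 6, 7].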

    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_len: Optional[torch.Tensor] = None
    # The number of query/output heads
    num_qo_heads: Optional[int] = None
    # The number of key/value heads
    num_kv_heads: Optional[int] = None
    # The dimension of the attention heads
    head_dim: Optional[int] = None
    # Block size of Aphrodite
    page_size: Optional[int] = None
    # The data type of the paged kv cache
    data_type: torch.dtype = None
    device: torch.device = torch.device("cuda")
    is_profile_run: bool = False

    def __post_init__(self):
        # Refer to
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
        if self.head_dim is not None and self.head_dim \
                not in supported_head_sizes:
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")

    def begin_forward(self):
        if self.num_prefill_tokens > 0:
            if self.paged_kv_indices is None:
                return

            assert self.prefill_wrapper is not None
            assert self.query_start_loc is not None
            assert self.paged_kv_indices is not None
            assert self.paged_kv_indptr is not None
            assert self.paged_kv_last_page_len is not None
            batch_size = self.query_start_loc.shape[0] - 1
            assert batch_size >= 0
            # We will use flash attention for profiling to
            # determine the number of blocks. Therefore,
            # we don't need to prepare the input for flashinfer
            # for the profile run.
            if not self.is_profile_run:
                self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
                self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                    self.device)
                self.paged_kv_indices = self.paged_kv_indices.to(self.device)
                self.prefill_wrapper.end_forward()
                self.prefill_wrapper.begin_forward(
                    self.query_start_loc, self.paged_kv_indptr,
                    self.paged_kv_indices, self.paged_kv_last_page_len,
                    self.num_qo_heads, self.num_kv_heads, self.head_dim,
                    self.page_size)
        else:
            if not self.use_cuda_graph:
                assert self.paged_kv_indices is not None
                assert self.paged_kv_indptr is not None
                assert self.paged_kv_last_page_len is not None
                self.paged_kv_indices = self.paged_kv_indices.to(self.device)
                self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
                self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                    self.device)

            assert self.decode_wrapper is not None
            self.decode_wrapper.end_forward()
            self.decode_wrapper.begin_forward(
                self.paged_kv_indptr,
                self.paged_kv_indices,
                self.paged_kv_last_page_len,
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                # Disable flashinfer's pos encoding and use Aphrodite's rope.
                pos_encoding_mode="NONE",
                data_type=self.data_type)

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        if skip_fields is None:
            skip_fields = set()
        # We need to skip the prefill/decode_wrapper field since it cannot be
        # broadcasted with nccl when TP is enabled.
        skip_fields.add('prefill_wrapper')
        skip_fields.add('decode_wrapper')
        return super().asdict_zerocopy(skip_fields)

    @property
    def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
        # Currently chunked prefill is not supported
        if self.num_decode_tokens == 0:
            assert self.num_prefills > 0
            return self
        return None

    @property
    def decode_metadata(self) -> Optional["FlashInferMetadata"]:
        # Currently chunked prefill is not supported
        if self.num_prefills > 0:
            assert self.num_decode_tokens == 0
            return None
        return self


class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size
        self.use_v2_block_manager = (
            input_builder.scheduler_config.use_v2_block_manager)

        # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
        # for the precise definition of the following fields.
        # An example:
        # request 1, page indices [0, 5, 8]
        # request 2, page indices [1, 6, 7]
        # request 3, page indices [3, 4]
        # paged_kv_indices is a concatenation of page indices of all requests:
        # [0, 5, 8, 1, 6, 7, 3, 4]
        # paged_kv_indptr is used to index into paged_kv_indices:
        # [0, 3, 6, 8]
        self.paged_kv_indices: List[int] = []
        # 0 at the beginning of paged_kv_indptr indicates the start of the
        # first request's page indices in the paged_kv_indices list.
        self.paged_kv_indptr: List[int] = [0]
        # paged_kv_last_page_len is the length of the last page of each request
        self.paged_kv_last_page_len: List[int] = []

        self.is_profile_run: bool = False

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables
        computed_block_nums = inter_data.computed_block_nums

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)
            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO: Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if inter_data.prefix_cache_hit:
                block_table = computed_block_nums
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                block_table = block_tables[seq_id][-curr_sliding_window_block:]
            self.block_tables.append(block_table)

            is_profile_run = is_block_tables_empty(block_tables)

            # Compute slot mapping.
            start_idx = compute_slot_mapping_start_idx(
                is_prompt, query_len, context_len, self.sliding_window,
                self.use_v2_block_manager)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

            # It is not necessary to add paged_kv_indices, paged_kv_indptr,
            # and paged_kv_last_page_len for profile run because we will
            # create dummy inputs.
            if is_profile_run:
                self.is_profile_run = is_profile_run
                return

            block_table = block_tables[seq_id]
            self._update_paged_kv_tensors(block_table, seq_len)

    def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
        # Get the number of valid blocks based on sequence length.
        # If seq_len = 16, block_size = 16,
        # block_table_bound is 1 with 1 valid block.
        # If seq_len = 15, block_size = 16,
        # block_table_bound is 0 + 1 with 1 valid block.
        block_table_bound = seq_len // self.block_size + 1 \
            if seq_len % self.block_size != 0 \
            else seq_len // self.block_size
        self.paged_kv_indices.extend(block_table[:block_table_bound])
        self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                    block_table_bound)
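
        # The last page of each request may be only partially filled. With
        # block_size = 16: seq_len = 16 -> the last page is full (16 entries);
        # seq_len = 15 -> the last page holds 15 entries.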
        last_page_len = seq_len % self.block_size
        if last_page_len == 0:
            last_page_len = self.block_size
        self.paged_kv_last_page_len.append(last_page_len)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        for inter_data in self.input_builder.inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # Pad block_tables with empty entries so it lines up with the
            # padded batch ([] * n is a no-op and extends nothing).
            self.block_tables.extend([[]] * cuda_graph_pad_size)
            num_decode_tokens = batch_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.runner.graph_block_tables[:batch_size]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.from_numpy(input_block_tables).to(
                device, non_blocking=True)

            last_paged_kv_indptr = self.paged_kv_indptr[-1]
            self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                        cuda_graph_pad_size)
            self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        assert device is not None
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device,
                                             self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
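
        # query_start_loc / seq_start_loc are exclusive prefix sums of the
        # query / sequence lengths, e.g. query_lens [4, 2] gives
        # query_start_loc [0, 4, 6].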
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        if len(self.paged_kv_indptr) > 0:
            paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
                                                   device="cpu",
                                                   dtype=torch.int)
            paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
                                                  device="cpu",
                                                  dtype=torch.int)
            paged_kv_last_page_len_tensor = torch.tensor(
                self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
        else:
            paged_kv_indices_tensor = None
            paged_kv_indptr_tensor = None
            paged_kv_last_page_len_tensor = None

        kv_cache_dtype = get_kv_cache_torch_dtype(
            self.runner.kv_cache_dtype, self.runner.model_config.dtype)
        return FlashInferMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            max_prefill_seq_len=max_prefill_seq_len,
            block_tables=block_tables,
            paged_kv_indptr=paged_kv_indptr_tensor,
            paged_kv_indices=paged_kv_indices_tensor,
            paged_kv_last_page_len=paged_kv_last_page_len_tensor,
            num_qo_heads=self.runner.model_config.get_num_attention_heads(
                self.runner.parallel_config),
            num_kv_heads=self.runner.model_config.get_num_kv_heads(
                self.runner.parallel_config),
            head_dim=self.runner.model_config.get_head_size(),
            page_size=self.block_size,
            seq_start_loc=seq_start_loc,
            query_start_loc=query_start_loc,
            device=device,
            data_type=kv_cache_dtype,
            use_cuda_graph=use_captured_graph,
            is_profile_run=self.is_profile_run)


class FlashInferImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        assert blocksparse_params is None, (
            "FlashInfer does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is not None:
            raise ValueError("Sliding window is not supported in FlashInfer.")
        self.sliding_window = (-1, -1)
        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: FlashInferMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        assert k_scale == 1.0 and v_scale == 1.0, (
            "k_scale and v_scale are not supported in FlashInfer.")
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashInferImpl")

        num_tokens, hidden_size = query.shape
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if attn_metadata.num_prefill_tokens > 0:
            assert attn_metadata.num_decode_tokens == 0, (
                "Chunked prefill is not supported with flashinfer yet.")
        if attn_metadata.num_decode_tokens > 0:
            assert attn_metadata.num_prefill_tokens == 0, (
                "Chunked prefill is not supported with flashinfer yet.")

        if kv_cache is not None:
            # Use the same reshape and cache kernel as flash attention.
            ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping.flatten(),
                self.kv_cache_dtype,
                k_scale,
                v_scale,
            )

        # FlashInfer requires the query to be contiguous.
        query = query.contiguous()
        if prefill_meta := attn_metadata.prefill_metadata:
            # We will use flash attention for prefill
            # when kv_cache is not provided.
            # This happens when Aphrodite runs the profiling to
            # determine the number of blocks.
            if kv_cache is None:
                output = torch.ops.aphrodite.flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_prefill_seq_len,
                    max_seqlen_k=prefill_meta.max_prefill_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
            else:
                assert prefill_meta is not None
                assert prefill_meta.prefill_wrapper is not None
                output = prefill_meta.prefill_wrapper.forward(
                    query,
                    kv_cache,
                    logits_soft_cap=self.logits_soft_cap,
                    causal=True)
        else:
            assert attn_metadata.decode_metadata is not None
            assert attn_metadata.decode_metadata.decode_wrapper is not None
            output = attn_metadata.decode_metadata.decode_wrapper.forward(
                query,
                kv_cache,
                sm_scale=self.scale,
                logits_soft_cap=self.logits_soft_cap)
        return output.view(num_tokens, hidden_size)
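
# A rough sketch of how the pieces above fit together (the actual call sites
# live in the model runner, not in this file):
#   1. FlashInferMetadataBuilder.build() turns the scheduled sequence groups
#      into a FlashInferMetadata holding the paged-KV index tensors.
#   2. FlashInferState.begin_forward() attaches the prefill/decode wrappers
#      and calls FlashInferMetadata.begin_forward() to plan the kernels.
#   3. FlashInferImpl.forward() writes new KV entries with
#      reshape_and_cache_flash() and runs either the prefill wrapper, the
#      decode wrapper, or flash-attention for the profile run (kv_cache is
#      None).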