model_runner.py 71 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585
  1. import dataclasses
  2. import gc
  3. import time
  4. import warnings
  5. from collections import defaultdict
  6. from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
  7. Tuple, Type, TypeVar, Union)
  8. import numpy as np
  9. import torch
  10. import torch.distributed
  11. import torch.nn as nn
  12. from loguru import logger
  13. try:
  14. from flashinfer import BatchDecodeWithPagedKVCacheWrapper
  15. from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
  16. from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
  17. FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
  18. except ImportError:
  19. BatchDecodeWithPagedKVCacheWrapper = None
  20. CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
  21. BatchPrefillWithPagedKVCacheWrapper = None
  22. FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
  23. from aphrodite.attention import AttentionMetadata, get_attn_backend
  24. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  25. LoRAConfig, ModelConfig, MultiModalConfig,
  26. ParallelConfig, PromptAdapterConfig,
  27. SchedulerConfig)
  28. from aphrodite.common.sampling_params import SamplingParams
  29. from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
  30. SequenceGroupMetadata)
  31. from aphrodite.common.utils import (CudaMemoryProfiler,
  32. get_kv_cache_torch_dtype, is_hip,
  33. is_pin_memory_available,
  34. make_tensor_with_pad)
  35. from aphrodite.distributed import get_pp_group
  36. from aphrodite.distributed.parallel_state import (
  37. get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
  38. graph_capture)
  39. from aphrodite.inputs import INPUT_REGISTRY
  40. from aphrodite.lora.layers import LoRAMapping
  41. from aphrodite.lora.request import LoRARequest
  42. from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
  43. from aphrodite.modeling import SamplingMetadata
  44. from aphrodite.modeling.model_loader import get_model
  45. from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
  46. from aphrodite.modeling.models.interfaces import supports_lora, supports_vision
  47. from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
  48. MultiModalInputs)
  49. from aphrodite.prompt_adapter.layers import PromptAdapterMapping
  50. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  51. from aphrodite.prompt_adapter.worker_manager import \
  52. LRUCacheWorkerPromptAdapterManager
  53. from aphrodite.task_handler.model_runner_base import (
  54. ModelRunnerBase, ModelRunnerInputBase,
  55. _add_attn_metadata_broadcastable_dict,
  56. _add_sampling_metadata_broadcastable_dict,
  57. _init_attn_metadata_from_tensor_dict,
  58. _init_sampling_metadata_from_tensor_dict)
  59. if TYPE_CHECKING:
  60. from aphrodite.attention.backends.abstract import AttentionBackend
  61. _PAD_SLOT_ID = -1
  62. LORA_WARMUP_RANK = 8
  63. _BATCH_SIZE_ALIGNMENT = 8
  64. # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
  65. # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
  66. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
  67. _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
  68. ]
  69. _NUM_WARMUP_ITERS = 2
  70. TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
  71. @dataclasses.dataclass(frozen=True)
  72. class ModelInputForGPU(ModelRunnerInputBase):
  73. """
  74. This base class contains metadata needed for the base model forward pass
  75. but not metadata for possible additional steps, e.g., sampling. Model
  76. runners that run additional steps should subclass this method to add
  77. additional fields.
  78. """
  79. input_tokens: Optional[torch.Tensor] = None
  80. input_positions: Optional[torch.Tensor] = None
  81. seq_lens: Optional[List[int]] = None
  82. query_lens: Optional[List[int]] = None
  83. lora_mapping: Optional["LoRAMapping"] = None
  84. lora_requests: Optional[Set[LoRARequest]] = None
  85. attn_metadata: Optional["AttentionMetadata"] = None
  86. prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
  87. prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
  88. multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
  89. request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
  90. finished_requests_ids: Optional[List[str]] = None
  91. virtual_engine: int = 0
  92. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  93. tensor_dict = {
  94. "input_tokens": self.input_tokens,
  95. "input_positions": self.input_positions,
  96. "lora_requests": self.lora_requests,
  97. "lora_mapping": self.lora_mapping,
  98. "multi_modal_kwargs": self.multi_modal_kwargs,
  99. "prompt_adapter_mapping": self.prompt_adapter_mapping,
  100. "prompt_adapter_requests": self.prompt_adapter_requests,
  101. "virtual_engine": self.virtual_engine,
  102. "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
  103. "finished_requests_ids": self.finished_requests_ids,
  104. }
  105. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  106. return tensor_dict
  107. @classmethod
  108. def from_broadcasted_tensor_dict(
  109. cls: Type[TModelInputForGPU],
  110. tensor_dict: Dict[str, Any],
  111. attn_backend: Optional["AttentionBackend"] = None,
  112. ) -> TModelInputForGPU:
  113. if attn_backend is not None:
  114. tensor_dict = _init_attn_metadata_from_tensor_dict(
  115. attn_backend, tensor_dict)
  116. return cls(**tensor_dict)
  117. @dataclasses.dataclass(frozen=True)
  118. class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
  119. """
  120. Used by the ModelRunner.
  121. """
  122. sampling_metadata: Optional["SamplingMetadata"] = None
  123. # Used for speculative decoding. We do not broadcast it because it is only
  124. # used by the driver worker.
  125. is_prompt: Optional[bool] = None
  126. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  127. tensor_dict = {
  128. "input_tokens": self.input_tokens,
  129. "input_positions": self.input_positions,
  130. "lora_requests": self.lora_requests,
  131. "lora_mapping": self.lora_mapping,
  132. "multi_modal_kwargs": self.multi_modal_kwargs,
  133. "prompt_adapter_mapping": self.prompt_adapter_mapping,
  134. "prompt_adapter_requests": self.prompt_adapter_requests,
  135. "virtual_engine": self.virtual_engine,
  136. "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
  137. "finished_requests_ids": self.finished_requests_ids,
  138. }
  139. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  140. _add_sampling_metadata_broadcastable_dict(tensor_dict,
  141. self.sampling_metadata)
  142. return tensor_dict
  143. @classmethod
  144. def from_broadcasted_tensor_dict(
  145. cls,
  146. tensor_dict: Dict[str, Any],
  147. attn_backend: Optional["AttentionBackend"] = None,
  148. ) -> "ModelInputForGPUWithSamplingMetadata":
  149. tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
  150. if attn_backend is not None:
  151. tensor_dict = _init_attn_metadata_from_tensor_dict(
  152. attn_backend, tensor_dict)
  153. return cls(**tensor_dict)
  154. class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
  155. """
  156. Helper class for shared methods between GPU model runners.
  157. """
  158. _model_input_cls: Type[TModelInputForGPU]
  159. def __init__(
  160. self,
  161. model_config: ModelConfig,
  162. parallel_config: ParallelConfig,
  163. scheduler_config: SchedulerConfig,
  164. device_config: DeviceConfig,
  165. cache_config: CacheConfig,
  166. load_config: LoadConfig,
  167. lora_config: Optional[LoRAConfig],
  168. kv_cache_dtype: Optional[str] = "auto",
  169. is_driver_worker: bool = False,
  170. prompt_adapter_config: Optional[PromptAdapterConfig] = None,
  171. multimodal_config: Optional[MultiModalConfig] = None,
  172. return_hidden_states: bool = False,
  173. ):
  174. self.model_config = model_config
  175. self.parallel_config = parallel_config
  176. self.scheduler_config = scheduler_config
  177. self.device_config = device_config
  178. self.cache_config = cache_config
  179. self.lora_config = lora_config
  180. self.load_config = load_config
  181. self.is_driver_worker = is_driver_worker
  182. self.prompt_adapter_config = prompt_adapter_config
  183. self.multimodal_config = multimodal_config
  184. self.return_hidden_states = return_hidden_states
  185. self.device = self.device_config.device
  186. self.pin_memory = is_pin_memory_available()
  187. self.kv_cache_dtype = kv_cache_dtype
  188. self.sliding_window = model_config.get_sliding_window()
  189. self.block_size = cache_config.block_size
  190. self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
  191. self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
  192. {} for _ in range(self.parallel_config.pipeline_parallel_size)
  193. ]
  194. self.graph_memory_pool: Optional[Tuple[
  195. int, int]] = None # Set during graph capture.
  196. self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers(
  197. parallel_config)
  198. # When using CUDA graph, the input block tables must be padded to
  199. # max_seq_len_to_capture. However, creating the block table in
  200. # Python can be expensive. To optimize this, we cache the block table
  201. # in numpy and only copy the actual input content at every iteration.
  202. # The shape of the cached block table will be
  203. # (max batch size to capture, max context len to capture / block size).
  204. self.graph_block_tables = np.zeros(
  205. (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
  206. dtype=np.int32)
  207. num_attn_heads = self.model_config.get_num_attention_heads(
  208. self.parallel_config)
  209. self.attn_backend = get_attn_backend(
  210. num_attn_heads,
  211. self.model_config.get_head_size(),
  212. self.model_config.get_num_kv_heads(self.parallel_config),
  213. self.model_config.get_sliding_window(),
  214. self.model_config.dtype,
  215. self.kv_cache_dtype,
  216. self.block_size,
  217. ) if num_attn_heads else None
  218. # Multi-modal data support
  219. self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
  220. .create_input_mapper(self.model_config)
  221. # Lazy initialization
  222. self.model: nn.Module # Set after load_model
  223. # Set after load_model.
  224. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
  225. self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None
  226. self.flashinfer_decode_workspace_buffer = None
  227. self.flashinfer_decode_wrapper = None
  228. self.flashinfer_prefill_workspace_buffer = None
  229. self.flashinfer_prefill_wrapper = None
  230. def load_model(self) -> None:
  231. with CudaMemoryProfiler() as m:
  232. # measure the time it takes to load the model
  233. start_time = time.time()
  234. self.model = get_model(model_config=self.model_config,
  235. device_config=self.device_config,
  236. load_config=self.load_config,
  237. lora_config=self.lora_config,
  238. multimodal_config=self.multimodal_config,
  239. parallel_config=self.parallel_config,
  240. scheduler_config=self.scheduler_config,
  241. cache_config=self.cache_config)
  242. end_time = time.time()
  243. self.model_memory_usage = m.consumed_memory
  244. tp = get_tensor_model_parallel_world_size()
  245. rank = get_tensor_model_parallel_rank()
  246. total_time = end_time - start_time
  247. if tp > 1:
  248. logger.info(
  249. f"Rank {rank}: Model weights loaded in {total_time:.2f} secs.")
  250. if rank == 0:
  251. logger.info(
  252. "Memory usage: "
  253. f"{self.model_memory_usage / float(2**30):.2f} GiB x {tp} ="
  254. f" {self.model_memory_usage * tp / float(2**30):.2f} GiB")
  255. else:
  256. logger.info(f"Model weights loaded in {total_time:.2f} seconds.")
  257. logger.info("Memory usage: "
  258. f"{self.model_memory_usage / float(2**30):.2f} GiB")
  259. if self.lora_config:
  260. assert supports_lora(self.model), "Model does not support LoRA"
  261. assert not supports_vision(
  262. self.model
  263. ), "To be tested: vision language model with LoRA settings."
  264. self.lora_manager = LRUCacheWorkerLoRAManager(
  265. self.scheduler_config.max_num_seqs,
  266. self.scheduler_config.max_num_batched_tokens,
  267. self.vocab_size,
  268. self.lora_config,
  269. self.device,
  270. self.model.embedding_modules,
  271. self.model.embedding_padding_modules,
  272. max_position_embeddings=self.model.config.
  273. max_position_embeddings,
  274. )
  275. self.model = self.lora_manager.create_lora_manager(self.model)
  276. if self.prompt_adapter_config:
  277. self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
  278. self.scheduler_config.max_num_seqs,
  279. self.scheduler_config.max_num_batched_tokens, self.device,
  280. self.prompt_adapter_config)
  281. self.model = (
  282. self.prompt_adapter_manager.create_prompt_adapter_manager(
  283. self.model))
  284. if self.kv_cache_dtype == "fp8" and is_hip():
  285. # Currently only ROCm accepts kv-cache scaling factors
  286. # via quantization_param_path and this will be deprecated
  287. # in the future.
  288. if self.model_config.quantization_param_path is not None:
  289. if callable(getattr(self.model, "load_kv_cache_scales", None)):
  290. warnings.warn(
  291. "Loading kv cache scaling factor from JSON is "
  292. "deprecated and will be removed. Please include "
  293. "kv cache scaling factors in the model checkpoint.",
  294. FutureWarning,
  295. stacklevel=2)
  296. self.model.load_kv_cache_scales(
  297. self.model_config.quantization_param_path)
  298. logger.info(
  299. "Loaded KV cache scaling factors from ",
  300. f"{self.model_config.quantization_param_path}")
  301. else:
  302. raise RuntimeError(
  303. "Using FP8 KV cache and scaling factors provided but "
  304. f"model {self.model.__class__} does not support loading"
  305. " scaling factors.", )
  306. else:
  307. logger.warning(
  308. "Using FP8 KV cache but no scaling factors "
  309. "provided. Defaulting to scaling factors of 1.0. "
  310. "This may lead to less accurate results!")
  311. def save_sharded_state(
  312. self,
  313. path: str,
  314. pattern: Optional[str] = None,
  315. max_size: Optional[int] = None,
  316. ) -> None:
  317. from aphrodite.modeling.model_loader.loader import ShardedStateLoader
  318. ShardedStateLoader.save_model(
  319. self.model,
  320. path,
  321. pattern=pattern,
  322. max_size=max_size,
  323. )
  324. def save_tensorized_model(
  325. self,
  326. tensorizer_config: TensorizerConfig,
  327. ) -> None:
  328. from aphrodite.modeling.model_loader.loader import TensorizerLoader
  329. TensorizerLoader.save_model(
  330. self.model,
  331. tensorizer_config=tensorizer_config,
  332. )
  333. def get_max_block_per_batch(self) -> int:
  334. block_size = self.block_size
  335. return (self.max_seq_len_to_capture + block_size - 1) // block_size
  336. def _prepare_model_input_tensors(
  337. self,
  338. seq_group_metadata_list: List[SequenceGroupMetadata],
  339. finished_requests_ids: Optional[List[str]] = None
  340. ) -> TModelInputForGPU:
  341. """Helper method to prepare the model input based on a given sequence
  342. group. Prepares metadata needed for the base model forward pass but not
  343. metadata for possible additional steps, e.g., sampling.
  344. The API assumes seq_group_metadata_list is sorted by prefill -> decode.
  345. The result tensors and data structure also batches input in prefill
  346. -> decode order. For example,
  347. - input_tokens[:num_prefill_tokens] contains prefill tokens.
  348. - input_tokens[num_prefill_tokens:] contains decode tokens.
  349. If cuda graph is required, this API automatically pads inputs.
  350. """
  351. input_tokens: List[int] = []
  352. input_positions: List[int] = []
  353. slot_mapping: List[int] = []
  354. lora_index_mapping: List[int] = []
  355. lora_prompt_mapping: List[int] = []
  356. lora_requests: Set[LoRARequest] = set()
  357. prompt_adapter_index_mapping: List[int] = []
  358. prompt_adapter_prompt_mapping: List[int] = []
  359. prompt_adapter_requests: Set[PromptAdapterRequest] = set()
  360. seq_lens: List[int] = []
  361. prefill_seq_lens: List[int] = []
  362. decode_seq_lens: List[int] = []
  363. context_lens: List[int] = []
  364. query_lens: List[int] = []
  365. block_tables: List[List[int]] = []
  366. multi_modal_inputs_list: List[MultiModalInputs] = []
  367. request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list)
  368. decode_only = True
  369. num_prefills = 0
  370. num_prefill_tokens = 0
  371. num_decode_tokens = 0
  372. # The following fields are only for flashinfer
  373. # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
  374. # for the precise definition of the following fields.
  375. # An example:
  376. # request 1, page indices [0, 5, 8]
  377. # request 2, page indices [1, 6, 7]
  378. # request 3, page indices [3, 4]
  379. # paged_kv_indices is a concatenation of page indices of all requests:
  380. # [0, 5, 8, 1, 6, 7, 3, 4]
  381. # paged_kv_indptr is used to index into paged_kv_indices:
  382. # [0, 3, 6, 8]
  383. paged_kv_indices: List[int] = []
  384. # 0 at the beginning of paged_kv_indptr indicates the start of the
  385. # first request’s page indices in the paged_kv_indices list.
  386. paged_kv_indptr: List[int] = [0]
  387. # paged_kv_last_page_len is the length of the last page of each request
  388. paged_kv_last_page_len: List[int] = []
  389. if len(seq_group_metadata_list) == 0:
  390. return self._model_input_cls()
  391. if self.sliding_window is not None:
  392. sliding_window_blocks = (self.sliding_window + self.block_size -
  393. 1) // self.block_size
  394. block_aligned_sliding_window = \
  395. sliding_window_blocks * self.block_size
  396. for seq_group_metadata in seq_group_metadata_list:
  397. seq_ids = list(seq_group_metadata.seq_data.keys())
  398. is_prompt = seq_group_metadata.is_prompt
  399. for seq_id in seq_ids:
  400. computed_block_nums = seq_group_metadata.computed_block_nums
  401. if (self.scheduler_config is not None
  402. and self.scheduler_config.chunked_prefill_enabled
  403. and not (computed_block_nums is None
  404. or computed_block_nums == [])):
  405. raise RuntimeError(
  406. "chunked prefill cannot be used with prefix caching "
  407. "now.")
  408. seq_data = seq_group_metadata.seq_data[seq_id]
  409. if is_prompt:
  410. context_len = seq_data.get_num_computed_tokens()
  411. else:
  412. # get_num_computed_tokens is incorrect for spec decoding.
  413. # So, we should have a special logic here.
  414. # TODO: Fix it.
  415. context_len = seq_data.get_len() - 1
  416. seq_len = min(
  417. seq_data.get_len(),
  418. context_len + seq_group_metadata.token_chunk_size)
  419. if is_prompt:
  420. tokens = seq_data.get_token_ids()[context_len:seq_len]
  421. else:
  422. # Optimization. get_token_ids requires the entire copy of
  423. # tokens.
  424. tokens = [seq_data.get_last_token_id()]
  425. # Prefix cache was hit.
  426. # Prefix is not supported with sliding_window
  427. prefix_cache_hit = (computed_block_nums is not None
  428. and len(computed_block_nums) > 0
  429. and self.sliding_window is None
  430. and is_prompt)
  431. # These are seq_len/context_len capped to the sliding window.
  432. # They are passed to decode kernel.
  433. # We still need original seq_len/context_len to compute slot
  434. # mapping (and input position) below.
  435. curr_sliding_window_blocks = None
  436. sliding_seq_len = seq_len
  437. sliding_context_len = context_len
  438. # TODO: This is a hack to make sliding window work with
  439. # paged attn. We can remove it if we make paged attn kernel
  440. # to properly handle slinding window attn.
  441. if (self.sliding_window is not None and not is_prompt):
  442. curr_sliding_window_blocks = sliding_window_blocks
  443. if self.scheduler_config.use_v2_block_manager:
  444. # number of elements in last block
  445. suff_len = seq_len % self.block_size
  446. sliding_seq_len = min(
  447. seq_len, block_aligned_sliding_window + suff_len)
  448. if suff_len > 0:
  449. curr_sliding_window_blocks += 1
  450. else:
  451. sliding_seq_len = min(seq_len, self.sliding_window)
  452. sliding_context_len = sliding_seq_len - 1
  453. # TODO: Combine chunked prefill and prefix caching by
  454. # only allowing multiple of block_size chunk size.
  455. # NOTE: This only works for oooooooxxx style attention.
  456. if prefix_cache_hit:
  457. assert computed_block_nums is not None
  458. context_len = len(computed_block_nums) * self.block_size
  459. tokens = tokens[context_len:]
  460. # need to think what to set it to when we have both sliding
  461. # window and prefix caching...
  462. assert self.sliding_window is None, \
  463. "Prefix caching is not supported with sliding window"
  464. sliding_context_len = context_len
  465. if self.attn_backend.get_name() == "flash-attn":
  466. # NOTE: For flash-attn, the block table should
  467. # include the entries for the incoming prefill tokens.
  468. # TODO: This is a temporary fix. We should
  469. # provide a unified interface for different backends.
  470. block_table = seq_group_metadata.block_tables[seq_id]
  471. else:
  472. block_table = computed_block_nums
  473. elif (self.scheduler_config.chunked_prefill_enabled
  474. or not is_prompt):
  475. if seq_group_metadata.block_tables is not None:
  476. # chunked prefill or decode
  477. block_table = seq_group_metadata.block_tables[seq_id]
  478. if curr_sliding_window_blocks is not None:
  479. block_table = block_table[
  480. -curr_sliding_window_blocks:]
  481. else:
  482. # Only happens when memory profiling runs.
  483. block_table = []
  484. else:
  485. # Prefill without chunked prefill or memory profiling.
  486. block_table = []
  487. block_tables.append(block_table)
  488. seq_lens.append(sliding_seq_len)
  489. context_lens.append(sliding_context_len)
  490. query_len = sliding_seq_len - sliding_context_len
  491. query_lens.append(query_len)
  492. input_tokens.extend(tokens)
  493. input_positions.extend(list(range(context_len, seq_len)))
  494. lora_id = seq_group_metadata.lora_int_id
  495. prompt_adapter_id = seq_group_metadata.prompt_adapter_id
  496. if is_prompt:
  497. assert len(seq_ids) == 1
  498. num_prefills += 1
  499. num_prefill_tokens += len(tokens)
  500. decode_only = False
  501. prefill_seq_lens.append(seq_len)
  502. else:
  503. assert query_len == 1, (
  504. "seq_len: {}, context_len: {}, query_len: {}".format(
  505. seq_len, context_len, query_len))
  506. num_decode_tokens += query_len
  507. decode_seq_lens.append(sliding_seq_len)
  508. if lora_id > 0:
  509. lora_requests.add(seq_group_metadata.lora_request)
  510. lora_index_mapping += [lora_id] * query_len
  511. lora_prompt_mapping.extend(
  512. [lora_id] *
  513. (query_len if seq_group_metadata.sampling_params
  514. and seq_group_metadata.sampling_params.prompt_logprobs
  515. is not None else 1))
  516. mm_data = seq_group_metadata.multi_modal_data
  517. if mm_data:
  518. # Process multi-modal data
  519. mm_kwargs = self.multi_modal_input_mapper(mm_data)
  520. multi_modal_inputs_list.append(mm_kwargs)
  521. if prompt_adapter_id > 0 and is_prompt:
  522. prompt_adapter_requests.add(
  523. seq_group_metadata.prompt_adapter_request)
  524. num_tokens = seq_group_metadata.\
  525. prompt_adapter_num_virtual_tokens
  526. pm = [prompt_adapter_id
  527. ] * num_tokens + [0] * (query_len - num_tokens)
  528. prompt_adapter_index_mapping += pm
  529. prompt_adapter_prompt_mapping.extend(
  530. [prompt_adapter_id] *
  531. (query_len if seq_group_metadata.sampling_params
  532. and seq_group_metadata.sampling_params.prompt_logprobs
  533. else 1))
  534. is_profile_run = _is_block_tables_empty(
  535. seq_group_metadata.block_tables)
  536. if is_profile_run:
  537. # During memory profiling, the block tables are not
  538. # initialized yet. In this case, we just use a dummy
  539. # slot mapping.
  540. # In embeddings, the block tables are {seq_id: None}.
  541. slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
  542. continue
  543. # Compute the slot mapping.
  544. block_table = seq_group_metadata.block_tables[seq_id]
  545. # Mask the [0, start_idx) tokens of the prompt with
  546. # _PAD_SLOT_ID, where start_idx is max(0, seq_len -
  547. # sliding_window). For example, if the prompt len is 10,
  548. # sliding window is 8, and block size is 4, the first two
  549. # tokens are masked and the slot mapping will be
  550. # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
  551. start_idx = 0
  552. if self.sliding_window is not None:
  553. if is_prompt:
  554. assert self.scheduler_config.use_v2_block_manager \
  555. or context_len == 0, (
  556. "Prefix caching is currently not supported with "
  557. "sliding window attention in V1 block manager")
  558. # It is an optimization. When it is decoding, it is always
  559. # 0. When prefill, we use it to not write slots to kv cache
  560. # to save memory.
  561. start_idx = max(0, query_len - self.sliding_window)
  562. for i in range(context_len, seq_len):
  563. if i < start_idx:
  564. slot_mapping.append(_PAD_SLOT_ID)
  565. continue
  566. block_number = block_table[i // self.block_size]
  567. block_offset = i % self.block_size
  568. slot = block_number * self.block_size + block_offset
  569. slot_mapping.append(slot)
  570. # Prepare input tensors for flashinfer
  571. if self.attn_backend.get_name() == "flashinfer":
  572. seq_len = seq_data.get_len()
  573. # Get the number of valid blocks based on sequence length.
  574. # If seq_len = 16, block_size = 16,
  575. # block_table_bound is 1 with 1 valid block.
  576. # If seq_len = 15, block_size = 16,
  577. # block_table_bound is 0 + 1 with 1 valid block.
  578. block_table_bound = seq_len // self.block_size + 1 \
  579. if seq_len % self.block_size != 0 \
  580. else seq_len // self.block_size
  581. paged_kv_indices.extend(block_table[:block_table_bound])
  582. paged_kv_indptr.append(paged_kv_indptr[-1] +
  583. block_table_bound)
  584. last_page_len = seq_len % self.block_size
  585. if last_page_len == 0:
  586. last_page_len = self.block_size
  587. paged_kv_last_page_len.append(last_page_len)
  588. batch_size = len(input_tokens)
  589. max_query_len = max(query_lens)
  590. max_prefill_seq_len = max(prefill_seq_lens, default=0)
  591. max_decode_seq_len = max(decode_seq_lens, default=0)
  592. # If cuda graph can be used, pad tensors accordingly.
  593. # See `capture_model` API for more details.
  594. # Aphrodite uses cuda graph only for decoding requests.
  595. use_captured_graph = (
  596. decode_only and not self.model_config.enforce_eager
  597. and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
  598. and max_decode_seq_len <= self.max_seq_len_to_capture)
  599. if use_captured_graph:
  600. graph_batch_size = _get_graph_batch_size(batch_size)
  601. assert graph_batch_size >= batch_size
  602. for _ in range(graph_batch_size - batch_size):
  603. input_tokens.append(0)
  604. input_positions.append(0)
  605. slot_mapping.append(_PAD_SLOT_ID)
  606. seq_lens.append(1)
  607. block_tables.append([])
  608. lora_index_mapping.append(0)
  609. prompt_adapter_index_mapping.append(0)
  610. if self.attn_backend.get_name() == "flashinfer":
  611. last_paged_kv_indptr = paged_kv_indptr[-1]
  612. paged_kv_indptr.append(last_paged_kv_indptr)
  613. paged_kv_last_page_len.append(0)
  614. batch_size = graph_batch_size
  615. num_decode_tokens = batch_size
  616. if use_captured_graph:
  617. # The shape of graph_block_tables is
  618. # [max batch size, max context len // block size].
  619. input_block_tables = self.graph_block_tables[:batch_size]
  620. for i, block_table in enumerate(block_tables):
  621. if block_table:
  622. input_block_tables[i, :len(block_table)] = block_table
  623. block_tables = torch.tensor(input_block_tables, device=self.device)
  624. else:
  625. max_block_table_len = max(
  626. len(block_table) for block_table in block_tables)
  627. block_tables = make_tensor_with_pad(
  628. block_tables,
  629. max_len=max_block_table_len,
  630. pad=0,
  631. dtype=torch.int,
  632. device=self.device,
  633. )
  634. assert max_query_len > 0, ("query_lens: {}".format(query_lens))
  635. context_lens_tensor = torch.tensor(context_lens,
  636. dtype=torch.int,
  637. device=self.device)
  638. seq_lens_tensor = torch.tensor(seq_lens,
  639. dtype=torch.int,
  640. device=self.device)
  641. query_lens_tensor = torch.tensor(query_lens,
  642. dtype=torch.long,
  643. device=self.device)
  644. query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
  645. dtype=torch.int32,
  646. device=self.device)
  647. seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
  648. dtype=torch.int32,
  649. device=self.device)
  650. torch.cumsum(seq_lens_tensor,
  651. dim=0,
  652. dtype=seq_start_loc.dtype,
  653. out=seq_start_loc[1:])
  654. torch.cumsum(query_lens_tensor,
  655. dim=0,
  656. dtype=query_start_loc.dtype,
  657. out=query_start_loc[1:])
  658. input_tokens_tensor = torch.tensor(input_tokens,
  659. dtype=torch.long,
  660. device=self.device)
  661. input_positions_tensor = torch.tensor(input_positions,
  662. dtype=torch.long,
  663. device=self.device)
  664. slot_mapping_tensor = torch.tensor(slot_mapping,
  665. dtype=torch.long,
  666. device=self.device)
  667. logits_soft_cap = getattr(self.model_config.hf_config,
  668. 'attn_logit_softcapping', None)
  669. if logits_soft_cap is not None and self.attn_backend.get_name(
  670. ) != "flashinfer":
  671. raise ValueError("Please use Flashinfer backend for models with"
  672. "logits_soft_cap (i.e., Gemma-2)."
  673. " Otherwise, the output might be wrong."
  674. " Set Flashinfer backend by "
  675. "export APHRODITE_ATTENTION_BACKEND=FLASHINFER.")
  676. if self.attn_backend.get_name() == "flashinfer":
  677. if len(paged_kv_indptr) > 0:
  678. paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
  679. device='cpu',
  680. dtype=torch.int)
  681. paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
  682. device='cpu',
  683. dtype=torch.int)
  684. paged_kv_last_page_len_tensor = torch.tensor(
  685. paged_kv_last_page_len, device='cpu', dtype=torch.int)
  686. else:
  687. paged_kv_indices_tensor = None
  688. paged_kv_indptr_tensor = None
  689. paged_kv_last_page_len_tensor = None
  690. kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
  691. self.model_config.dtype)
  692. attn_metadata = self.attn_backend.make_metadata(
  693. num_prefills=num_prefills,
  694. slot_mapping=slot_mapping_tensor,
  695. num_prefill_tokens=num_prefill_tokens,
  696. num_decode_tokens=num_decode_tokens,
  697. max_prefill_seq_len=max_prefill_seq_len,
  698. block_tables=block_tables,
  699. paged_kv_indptr=paged_kv_indptr_tensor,
  700. paged_kv_indices=paged_kv_indices_tensor,
  701. paged_kv_last_page_len=paged_kv_last_page_len_tensor,
  702. num_qo_heads=self.model_config.get_num_attention_heads(
  703. self.parallel_config),
  704. num_kv_heads=self.model_config.get_num_kv_heads(
  705. self.parallel_config),
  706. head_dim=self.model_config.get_head_size(),
  707. page_size=self.block_size,
  708. seq_start_loc=seq_start_loc,
  709. query_start_loc=query_start_loc,
  710. device=self.device,
  711. data_type=kv_cache_dtype,
  712. use_cuda_graph=use_captured_graph,
  713. logits_soft_cap=logits_soft_cap)
  714. else:
  715. attn_metadata = self.attn_backend.make_metadata(
  716. num_prefills=num_prefills,
  717. slot_mapping=slot_mapping_tensor,
  718. num_prefill_tokens=num_prefill_tokens,
  719. num_decode_tokens=num_decode_tokens,
  720. seq_lens=seq_lens,
  721. seq_lens_tensor=seq_lens_tensor,
  722. max_query_len=max_query_len,
  723. max_prefill_seq_len=max_prefill_seq_len,
  724. max_decode_seq_len=max_decode_seq_len,
  725. query_start_loc=query_start_loc,
  726. seq_start_loc=seq_start_loc,
  727. context_lens_tensor=context_lens_tensor,
  728. block_tables=block_tables,
  729. use_cuda_graph=use_captured_graph,
  730. )
  731. if self.lora_config:
  732. lora_mapping = LoRAMapping(
  733. lora_index_mapping,
  734. lora_prompt_mapping,
  735. )
  736. else:
  737. lora_mapping = None
  738. if self.prompt_adapter_config:
  739. prompt_adapter_mapping = PromptAdapterMapping(
  740. prompt_adapter_index_mapping,
  741. prompt_adapter_prompt_mapping,
  742. )
  743. else:
  744. prompt_adapter_mapping = None
  745. multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
  746. device=self.device)
  747. request_ids_to_seq_ids = {
  748. seq_group_metadata.request_id:
  749. list(seq_group_metadata.seq_data.keys())
  750. for seq_group_metadata in seq_group_metadata_list
  751. }
  752. return self._model_input_cls(
  753. input_tokens=input_tokens_tensor,
  754. input_positions=input_positions_tensor,
  755. attn_metadata=attn_metadata,
  756. seq_lens=seq_lens,
  757. query_lens=query_lens,
  758. lora_mapping=lora_mapping,
  759. lora_requests=lora_requests,
  760. multi_modal_kwargs=multi_modal_kwargs,
  761. request_ids_to_seq_ids=request_ids_to_seq_ids,
  762. finished_requests_ids=finished_requests_ids,
  763. prompt_adapter_mapping=prompt_adapter_mapping,
  764. prompt_adapter_requests=prompt_adapter_requests,
  765. )
  766. @torch.inference_mode()
  767. def profile_run(self) -> None:
  768. # Enable top-k sampling to reflect the accurate memory usage.
  769. sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
  770. max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
  771. max_num_seqs = self.scheduler_config.max_num_seqs
  772. # This represents the maximum number of different requests
  773. # that will have unique loras, an therefore the max amount of memory
  774. # consumption create dummy lora request copies from the lora request
  775. # passed in, which contains a lora from the lora warmup path.
  776. dummy_lora_requests: List[LoRARequest] = []
  777. dummy_lora_requests_per_seq: List[LoRARequest] = []
  778. if self.lora_config:
  779. assert self.lora_manager is not None
  780. with self.lora_manager.dummy_lora_cache():
  781. for idx in range(self.lora_config.max_loras):
  782. lora_id = idx + 1
  783. dummy_lora_request = LoRARequest(
  784. lora_name=f"warmup_{lora_id}",
  785. lora_int_id=lora_id,
  786. lora_local_path="/not/a/real/path",
  787. )
  788. self.lora_manager.add_dummy_lora(dummy_lora_request,
  789. rank=LORA_WARMUP_RANK)
  790. dummy_lora_requests.append(dummy_lora_request)
  791. dummy_lora_requests_per_seq = [
  792. dummy_lora_requests[idx % len(dummy_lora_requests)]
  793. for idx in range(max_num_seqs)
  794. ]
  795. # Profile memory usage with max_num_sequences sequences and the total
  796. # number of tokens equal to max_num_batched_tokens.
  797. seqs: List[SequenceGroupMetadata] = []
  798. # Additional GPU memory may be needed for vision encoding, which needs
  799. # to be accounted for when calculating the GPU blocks for
  800. # Aphrodite blocker manager.
  801. # To exercise the worst scenario for GPU memory consumption,
  802. # the number of seqs (batch_size) is chosen to maximize the number
  803. # of images processed.
  804. model_config = self.model_config
  805. if supports_vision(self.model):
  806. max_mm_tokens = MULTIMODAL_REGISTRY \
  807. .get_max_multimodal_tokens(model_config)
  808. max_num_seqs_orig = max_num_seqs
  809. max_num_seqs = min(max_num_seqs,
  810. max_num_batched_tokens // max_mm_tokens)
  811. if max_num_seqs < 1:
  812. expr = (f"min({max_num_seqs_orig}, "
  813. f"{max_num_batched_tokens} // {max_mm_tokens})")
  814. logger.warning(
  815. f"Computed max_num_seqs ({expr}) to be less than 1. "
  816. "Setting it to the minimum value of 1.")
  817. max_num_seqs = 1
  818. batch_size = 0
  819. for group_id in range(max_num_seqs):
  820. seq_len = (max_num_batched_tokens // max_num_seqs +
  821. (group_id < max_num_batched_tokens % max_num_seqs))
  822. batch_size += seq_len
  823. seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
  824. .dummy_data_for_profiling(model_config, seq_len)
  825. # Having more tokens is over-conservative but otherwise fine
  826. assert len(seq_data.prompt_token_ids) >= seq_len, (
  827. f"Expected at least {seq_len} dummy tokens for profiling, "
  828. f"but got: {len(seq_data.prompt_token_ids)}")
  829. seq = SequenceGroupMetadata(
  830. request_id=str(group_id),
  831. is_prompt=True,
  832. seq_data={group_id: seq_data},
  833. sampling_params=sampling_params,
  834. block_tables=None,
  835. lora_request=dummy_lora_requests_per_seq[group_id]
  836. if dummy_lora_requests_per_seq else None,
  837. multi_modal_data=dummy_multi_modal_data,
  838. )
  839. seqs.append(seq)
  840. # Run the model with the dummy inputs.
  841. num_layers = self.model_config.get_num_layers(self.parallel_config)
  842. kv_caches = [None] * num_layers
  843. finished_requests_ids = [seq.request_id for seq in seqs]
  844. model_input = self.prepare_model_input(
  845. seqs, finished_requests_ids=finished_requests_ids)
  846. intermediate_tensors = None
  847. if not get_pp_group().is_first_rank:
  848. intermediate_tensors = self.model.make_empty_intermediate_tensors(
  849. batch_size=batch_size,
  850. dtype=self.model_config.dtype,
  851. device=self.device)
  852. self.execute_model(model_input, kv_caches, intermediate_tensors)
  853. torch.cuda.synchronize()
  854. return
  855. def remove_all_loras(self):
  856. if not self.lora_manager:
  857. raise RuntimeError("LoRA is not enabled.")
  858. self.lora_manager.remove_all_adapters()
  859. def set_active_loras(self, lora_requests: Set[LoRARequest],
  860. lora_mapping: LoRAMapping) -> None:
  861. if not self.lora_manager:
  862. raise RuntimeError("LoRA is not enabled.")
  863. self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
  864. def add_lora(self, lora_request: LoRARequest) -> bool:
  865. if not self.lora_manager:
  866. raise RuntimeError("LoRA is not enabled.")
  867. return self.lora_manager.add_adapter(lora_request)
  868. def remove_lora(self, lora_id: int) -> bool:
  869. if not self.lora_manager:
  870. raise RuntimeError("LoRA is not enabled.")
  871. return self.lora_manager.remove_adapter(lora_id)
  872. def pin_lora(self, lora_id: int) -> bool:
  873. if not self.lora_manager:
  874. raise RuntimeError("LoRA is not enabled.")
  875. return self.lora_manager.pin_adapter(lora_id)
  876. def list_loras(self) -> Set[int]:
  877. if not self.lora_manager:
  878. raise RuntimeError("LoRA is not enabled.")
  879. return self.lora_manager.list_adapters()
  880. def remove_all_prompt_adapters(self):
  881. if not self.prompt_adapter_manager:
  882. raise RuntimeError("PromptAdapter is not enabled.")
  883. self.prompt_adapter_manager.remove_all_adapters()
  884. def set_active_prompt_adapters(
  885. self, prompt_adapter_requests: Set[PromptAdapterRequest],
  886. prompt_adapter_mapping: PromptAdapterMapping) -> None:
  887. if not self.prompt_adapter_manager:
  888. raise RuntimeError("PromptAdapter is not enabled.")
  889. self.prompt_adapter_manager.set_active_adapters(
  890. prompt_adapter_requests, prompt_adapter_mapping)
  891. def add_prompt_adapter(
  892. self, prompt_adapter_request: PromptAdapterRequest) -> bool:
  893. if not self.prompt_adapter_manager:
  894. raise RuntimeError("PromptAdapter is not enabled.")
  895. return self.prompt_adapter_manager.add_adapter(prompt_adapter_request)
  896. def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  897. if not self.prompt_adapter_manager:
  898. raise RuntimeError("PromptAdapter is not enabled.")
  899. return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id)
  900. def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  901. if not self.prompt_adapter_manager:
  902. raise RuntimeError("PromptAdapter is not enabled.")
  903. return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id)
  904. def list_prompt_adapters(self) -> Set[int]:
  905. if not self.prompt_adapter_manager:
  906. raise RuntimeError("PromptAdapter is not enabled.")
  907. return self.prompt_adapter_manager.list_adapters()
  908. @torch.inference_mode()
  909. def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
  910. """Cuda graph capture a model.
  911. Note that CUDA graph's performance gain is negligible if number
  912. of batched tokens are larger than 200. And since CUDA graph
  913. requires fixed sized tensors, supporting large/variable batch
  914. size requires high GPU memory overhead. Thus, Aphrodite only captures
  915. decoding requests. Mixed batch (chunked prefill + decoding) or
  916. prefill requests are not captured.
  917. Since it is used for decoding-only, it assumes there's only 1 token
  918. per sequence in the batch.
  919. """
  920. assert not self.model_config.enforce_eager
  921. logger.info("Capturing the model for CUDA graphs. This may lead to "
  922. "unexpected consequences if the model is not static. To "
  923. "run the model in eager mode, set 'enforce_eager=True' or "
  924. "use '--enforce-eager' in the CLI.")
  925. logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
  926. "If you are running out of memory, consider decreasing "
  927. "`gpu_memory_utilization` or enforcing eager mode. "
  928. "You can also reduce the `max_num_seqs` as needed "
  929. "to decrease memory usage.")
  930. start_time = time.perf_counter()
  931. # Prepare dummy inputs. These will be reused for all batch sizes.
  932. max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
  933. input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
  934. input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
  935. slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
  936. slot_mapping.fill_(_PAD_SLOT_ID)
  937. seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
  938. block_tables = torch.from_numpy(self.graph_block_tables).cuda()
  939. intermediate_inputs = None
  940. if not get_pp_group().is_first_rank:
  941. intermediate_inputs = self.model.make_empty_intermediate_tensors(
  942. batch_size=max_batch_size,
  943. dtype=self.model_config.dtype,
  944. device=self.device)
  945. # Prepare buffer for outputs. These will be reused for all batch sizes.
  946. # It will be filled after the first graph capture.
  947. hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [
  948. None
  949. ] * self.parallel_config.pipeline_parallel_size
  950. graph_batch_size = _get_graph_batch_size(
  951. self.scheduler_config.max_num_seqs)
  952. batch_size_capture_list = [
  953. bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
  954. ]
  955. if self.attn_backend.get_name() == "flashinfer":
  956. # For flashinfer, different batch sizes will share the
  957. # same workspace buffer.
  958. decode_workspace_buffer = \
  959. torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE,
  960. dtype=torch.uint8,
  961. device=self.device)
  962. indices_buffer = torch.empty(max_batch_size *
  963. self.cache_config.num_gpu_blocks,
  964. dtype=torch.int32,
  965. device=self.device)
  966. indptr_buffer = torch.empty(max_batch_size + 1,
  967. dtype=torch.int32,
  968. device=self.device)
  969. last_page_len_buffer = torch.empty(max_batch_size,
  970. dtype=torch.int32,
  971. device=self.device)
  972. with graph_capture() as graph_capture_context:
  973. # NOTE: Capturing the largest batch size first may help reduce the
  974. # memory usage of CUDA graph.
  975. for virtual_engine in range(
  976. self.parallel_config.pipeline_parallel_size):
  977. for batch_size in reversed(batch_size_capture_list):
  978. if self.attn_backend.get_name() == "flashinfer":
  979. indptr_buffer = indptr_buffer[:batch_size + 1]
  980. last_page_len_buffer = last_page_len_buffer[:
  981. batch_size]
  982. num_qo_heads = (
  983. self.model_config.get_num_attention_heads(
  984. self.parallel_config))
  985. num_kv_heads = self.model_config.get_num_kv_heads(
  986. self.parallel_config)
  987. if num_qo_heads // num_kv_heads >= 4:
  988. use_tensor_cores = True
  989. else:
  990. use_tensor_cores = False
  991. decode_wrapper = \
  992. CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
  993. decode_workspace_buffer, indptr_buffer,
  994. indices_buffer, last_page_len_buffer, "NHD",
  995. use_tensor_cores)
  996. kv_cache_dtype = get_kv_cache_torch_dtype(
  997. self.kv_cache_dtype, self.model_config.dtype)
  998. paged_kv_indptr_tensor_host = torch.arange(
  999. 0, batch_size + 1, dtype=torch.int32)
  1000. paged_kv_indices_tensor_host = torch.arange(
  1001. 0, batch_size, dtype=torch.int32)
  1002. paged_kv_last_page_len_tensor_host = torch.full(
  1003. (batch_size, ), self.block_size, dtype=torch.int32)
  1004. query_start_loc_host = torch.arange(0,
  1005. batch_size + 1,
  1006. dtype=torch.int32)
  1007. attn_metadata = self.attn_backend.make_metadata(
  1008. num_prefills=0,
  1009. slot_mapping=slot_mapping[:batch_size],
  1010. num_prefill_tokens=0,
  1011. num_decode_tokens=batch_size,
  1012. max_prefill_seq_len=0,
  1013. block_tables=block_tables,
  1014. paged_kv_indptr=paged_kv_indptr_tensor_host,
  1015. paged_kv_indices=paged_kv_indices_tensor_host,
  1016. paged_kv_last_page_len=
  1017. paged_kv_last_page_len_tensor_host,
  1018. num_qo_heads=num_qo_heads,
  1019. num_kv_heads=num_kv_heads,
  1020. head_dim=self.model_config.get_head_size(),
  1021. page_size=self.block_size,
  1022. seq_start_loc=None,
  1023. query_start_loc=query_start_loc_host,
  1024. device=self.device,
  1025. data_type=kv_cache_dtype,
  1026. use_cuda_graph=True,
  1027. decode_wrapper=decode_wrapper,
  1028. prefill_wrapper=None)
  1029. attn_metadata.begin_forward()
  1030. else:
  1031. attn_metadata = self.attn_backend.make_metadata(
  1032. num_prefills=0,
  1033. num_prefill_tokens=0,
  1034. num_decode_tokens=batch_size,
  1035. slot_mapping=slot_mapping[:batch_size],
  1036. seq_lens=None,
  1037. seq_lens_tensor=seq_lens[:batch_size],
  1038. max_query_len=None,
  1039. max_prefill_seq_len=0,
  1040. max_decode_seq_len=self.max_seq_len_to_capture,
  1041. query_start_loc=None,
  1042. seq_start_loc=None,
  1043. context_lens_tensor=None,
  1044. block_tables=block_tables[:batch_size],
  1045. use_cuda_graph=True,
  1046. )
  1047. if self.lora_config:
  1048. lora_mapping = LoRAMapping(
  1049. [0] * batch_size,
  1050. [0] * batch_size,
  1051. )
  1052. self.set_active_loras(set(), lora_mapping)
  1053. if self.prompt_adapter_config:
  1054. prompt_adapter_mapping = PromptAdapterMapping(
  1055. [-1] * batch_size,
  1056. [-1] * batch_size,
  1057. )
  1058. self.set_active_prompt_adapters(
  1059. set(), prompt_adapter_mapping)
  1060. graph_runner = CUDAGraphRunner(
  1061. self.model, self.attn_backend.get_name())
  1062. if self.attn_backend.get_name() == "flashinfer":
  1063. graph_runner.flashinfer_indptr_buffer = indptr_buffer
  1064. graph_runner.flashinfer_indices_buffer = indices_buffer
  1065. graph_runner.flashinfer_last_page_len_buffer = \
  1066. last_page_len_buffer
  1067. graph_runner.flashinfer_decode_workspace_buffer = \
  1068. decode_workspace_buffer
  1069. graph_runner.flashinfer_decode_wrapper = \
  1070. decode_wrapper
  1071. capture_inputs = {
  1072. "input_ids":
  1073. input_tokens[:batch_size],
  1074. "positions":
  1075. input_positions[:batch_size],
  1076. "hidden_or_intermediate_states":
  1077. hidden_or_intermediate_states[
  1078. virtual_engine] # type: ignore
  1079. [:batch_size]
  1080. if hidden_or_intermediate_states[virtual_engine]
  1081. is not None else None,
  1082. "intermediate_inputs":
  1083. intermediate_inputs[:batch_size]
  1084. if intermediate_inputs is not None else None,
  1085. "kv_caches":
  1086. kv_caches[virtual_engine],
  1087. "attn_metadata":
  1088. attn_metadata,
  1089. "memory_pool":
  1090. self.graph_memory_pool,
  1091. "stream":
  1092. graph_capture_context.stream
  1093. }
  1094. if self.has_seqlen_agnostic:
  1095. # Only used by Mamba-based models CUDA graph atm (Jamba)
  1096. capture_inputs.update({
  1097. "seqlen_agnostic_capture_inputs":
  1098. self.model.get_seqlen_agnostic_capture_inputs(
  1099. batch_size)
  1100. })
  1101. graph_runner.capture(**capture_inputs)
  1102. self.graph_memory_pool = graph_runner.graph.pool()
  1103. self.graph_runners[virtual_engine][batch_size] = (
  1104. graph_runner)
  1105. end_time = time.perf_counter()
  1106. elapsed_time = end_time - start_time
  1107. # This usually takes < 10 seconds.
  1108. logger.info(f"Graph capturing finished in {elapsed_time:2f} secs.")
  1109. @property
  1110. def vocab_size(self) -> int:
  1111. return self.model_config.get_vocab_size()
class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
    """
    GPU model runner with sampling step.

    Its prepared inputs additionally carry ``SamplingMetadata``. Its
    ``execute_model`` computes logits on the last pipeline stage and, on the
    driver worker, samples the next tokens.
    """
    _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
        ModelInputForGPUWithSamplingMetadata)

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Rebuild a model input from a broadcast tensor dict.

        This is a thin wrapper over
        ``ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict``
        that supplies this runner's attention backend.
        """
        model_input = \
            ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
                tensor_dict,
                attn_backend=self.attn_backend,
            )
        return model_input

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        If cuda graph is required, this API automatically pads inputs.

        Args:
            seq_group_metadata_list: the sequence groups to batch.
            virtual_engine: stored on the returned input. execute_model()
                uses it to pick this virtual engine's graph runners.
            finished_requests_ids: forwarded to
                ``_prepare_model_input_tensors``.

        Returns:
            The prepared input with ``sampling_metadata``, ``is_prompt`` and
            ``virtual_engine`` filled in. ``is_prompt`` comes from the first
            group, or is None for an empty list.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                     model_input.seq_lens,
                                                     model_input.query_lens,
                                                     self.device,
                                                     self.pin_memory)
        # Groups are sorted prefill -> decode, so the first group tells
        # whether the batch contains any prompt.
        is_prompt = (seq_group_metadata_list[0].is_prompt
                     if seq_group_metadata_list else None)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   is_prompt=is_prompt,
                                   virtual_engine=virtual_engine)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
        """Run one forward step and, on the last pipeline stage, sample.

        Args:
            model_input: output of prepare_model_input() (or its broadcast
                reconstruction).
            kv_caches: passed straight to the model / graph runner.
            intermediate_tensors: inputs from the previous pipeline stage.
            num_steps: only 1 is supported.

        Returns:
            The model output (intermediate tensors) on non-last pipeline
            ranks, ``[]`` on non-driver workers of the last rank, and
            otherwise a one-element list holding the ``SamplerOutput``.

        Raises:
            ValueError: if ``num_steps > 1``.
        """
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in ModelRunner")

        # Activate the adapters this batch needs before running the model.
        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        if self.attn_backend.get_name() == "flashinfer":
            assert model_input.attn_metadata is not None
            assert model_input.input_tokens is not None
            # Allocate the eager-mode FlashInfer workspaces and wrappers
            # lazily, the first time this path runs.
            if self.flashinfer_decode_workspace_buffer is None:
                self.flashinfer_decode_workspace_buffer = torch.empty(
                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
                    dtype=torch.uint8,
                    device=self.device)
                self.flashinfer_decode_wrapper = \
                    BatchDecodeWithPagedKVCacheWrapper(
                    self.flashinfer_decode_workspace_buffer, "NHD")
                self.flashinfer_prefill_workspace_buffer = torch.empty(
                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
                    dtype=torch.uint8,
                    device=self.device)
                self.flashinfer_prefill_wrapper = \
                    BatchPrefillWithPagedKVCacheWrapper(
                    self.flashinfer_prefill_workspace_buffer, "NHD")

            model_input.attn_metadata.prefill_wrapper = \
                self.flashinfer_prefill_wrapper
            if model_input.attn_metadata.use_cuda_graph:
                # Graph replay has to use the decode wrapper that was bound
                # to the captured graph for this (padded) batch size.
                batch_size = model_input.input_tokens.shape[0]
                model_input.attn_metadata.decode_wrapper = self.graph_runners[
                    model_input.
                    virtual_engine][batch_size].flashinfer_decode_wrapper
            else:
                model_input.attn_metadata.decode_wrapper = \
                    self.flashinfer_decode_wrapper
            model_input.attn_metadata.begin_forward()

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        # TODO: We can remove this once all
        # virtual engines share the same kv cache.
        virtual_engine = model_input.virtual_engine
        # Decode-only batches flagged for CUDA graphs replay the graph runner
        # captured for this virtual engine and padded batch size. Every other
        # batch runs the model eagerly.
        if prefill_meta is None and decode_meta.use_cuda_graph:
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[virtual_engine][
                graph_batch_size]
        else:
            model_executable = self.model

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        # Extra bookkeeping kwargs, passed only to models flagged as
        # has_seqlen_agnostic.
        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_seqlen_agnostic else {}
        hidden_or_intermediate_states = model_executable(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            kv_caches=kv_caches,
            attn_metadata=model_input.attn_metadata,
            intermediate_tensors=intermediate_tensors,
            **multi_modal_kwargs,
            **seqlen_agnostic_kwargs,
        )

        # Compute the logits in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            return hidden_or_intermediate_states

        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        # Only the driver worker samples; other workers return nothing.
        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )

        if self.return_hidden_states:
            # we only need to pass hidden states of most recent token
            assert model_input.sampling_metadata is not None
            indices = model_input.sampling_metadata.selected_token_indices
            if model_input.is_prompt:
                hidden_states = hidden_or_intermediate_states.index_select(
                    0, indices)
            elif decode_meta.use_cuda_graph:
                # The graph output is padded to the captured batch size, so
                # keep only the rows for the real sequences.
                hidden_states = hidden_or_intermediate_states[:len(indices)]
            else:
                hidden_states = hidden_or_intermediate_states

            output.hidden_states = hidden_states

        return [output]
  1259. class CUDAGraphRunner:
  1260. def __init__(self, model: nn.Module, backend_name: str):
  1261. self.model = model
  1262. self.backend_name = backend_name
  1263. self.input_buffers: Dict[str, torch.Tensor] = {}
  1264. self.output_buffers: Dict[str, torch.Tensor] = {}
  1265. self._graph: Optional[torch.cuda.CUDAGraph] = None
  1266. self.flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None
  1267. self.flashinfer_indptr_buffer: Optional[torch.Tensor] = None
  1268. self.flashinfer_indices_buffer: Optional[torch.Tensor] = None
  1269. self.flashinfer_last_page_len_buffer: Optional[torch.Tensor] = None
  1270. self.flashinfer_decode_wrapper: Optional[
  1271. CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None
  1272. @property
  1273. def graph(self):
  1274. assert self._graph is not None
  1275. return self._graph
    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
                                                      torch.Tensor]],
        intermediate_inputs: Optional[IntermediateTensors],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.cuda.Stream,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Warm up the model, then record one forward pass into a CUDA graph.

        The tensors passed in become the graph's static inputs. They are
        kept in ``self.input_buffers`` so that forward() can copy new values
        into them in place before replaying. The captured output is kept in
        ``self.output_buffers``.

        Args:
            input_ids: static token-id input.
            positions: static position input.
            hidden_or_intermediate_states: optional preallocated output. When
                given, the model output is copied into it inside the graph.
                Otherwise the model's own output tensor(s) are kept.
            intermediate_inputs: inputs from the previous pipeline stage, or
                None. Its tensors are also registered as input buffers.
            kv_caches: forwarded to the model.
            attn_metadata: forwarded to the model. Its slot mapping (and,
                for non-flashinfer backends, its decode seq_lens/block
                tables) is registered as input buffers.
            memory_pool: graph memory pool passed to ``torch.cuda.graph``.
            stream: stream to capture on.
            **kwargs: extra model kwargs; also stored in ``input_buffers``.

        Returns:
            The captured output: a tensor on the last pipeline rank,
            otherwise intermediate tensors.
        """
        assert self._graph is None
        # Run the model a few times without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        # Note one iteration is not enough for torch.jit.script
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                intermediate_inputs,
                **kwargs,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            output_hidden_or_intermediate_states = self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                intermediate_inputs,
                **kwargs,
            )
            if hidden_or_intermediate_states is not None:
                # Write into the caller-provided buffer as part of the graph
                # so that replays refresh it.
                if get_pp_group().is_last_rank:
                    hidden_or_intermediate_states.copy_(
                        output_hidden_or_intermediate_states)
                else:
                    for key in hidden_or_intermediate_states.tensors:
                        hidden_or_intermediate_states[key].copy_(
                            output_hidden_or_intermediate_states[key])
            else:
                hidden_or_intermediate_states = (
                    output_hidden_or_intermediate_states)

            del output_hidden_or_intermediate_states
            # make sure `output_hidden_states` is deleted
            # in the graph's memory pool
            gc.collect()
        torch.cuda.synchronize()

        # Save the input and output buffers.
        # FlashInfer tracks its paged-KV metadata through its own buffers
        # (attached as flashinfer_* attributes), so only the other backends
        # register seq_lens_tensor and block_tables here.
        if self.backend_name == "flashinfer":
            self.input_buffers = {
                "input_ids": input_ids,
                "positions": positions,
                "kv_caches": kv_caches,
                "slot_mapping": attn_metadata.slot_mapping,
                **kwargs,
            }
        else:
            self.input_buffers = {
                "input_ids": input_ids,
                "positions": positions,
                "kv_caches": kv_caches,
                "slot_mapping": attn_metadata.slot_mapping,
                "seq_lens_tensor":
                attn_metadata.decode_metadata.seq_lens_tensor,
                "block_tables": attn_metadata.decode_metadata.block_tables,
                **kwargs,
            }
        if intermediate_inputs is not None:
            self.input_buffers.update(intermediate_inputs.tensors)
        if get_pp_group().is_last_rank:
            self.output_buffers = {
                "hidden_states": hidden_or_intermediate_states
            }
        else:
            self.output_buffers = hidden_or_intermediate_states
        return hidden_or_intermediate_states
  def forward(
      self,
      input_ids: torch.Tensor,
      positions: torch.Tensor,
      kv_caches: List[torch.Tensor],
      attn_metadata: AttentionMetadata,
      intermediate_tensors: Optional[IntermediateTensors],
      **kwargs,
  ) -> Union[torch.Tensor, IntermediateTensors]:
      """Copy new inputs into the captured static buffers and replay the graph.

      Args:
          input_ids: Copied into ``input_buffers["input_ids"]``.
          positions: Copied into ``input_buffers["positions"]``.
          kv_caches: Unused here. The KV caches are fixed tensors that
              ``capture`` already stored in ``input_buffers["kv_caches"]``,
              so nothing needs to be copied.
          attn_metadata: ``slot_mapping`` is always copied. The decode
              ``seq_lens_tensor`` and ``block_tables`` are copied only for
              non-flashinfer backends, because only those backends get such
              buffers in ``capture``.
          intermediate_tensors: When not None, each tensor is copied into
              the input buffer registered under the same key.
          **kwargs: Passed to the model's
              ``copy_inputs_before_cuda_graphs`` and
              ``copy_outputs_after_cuda_graphs`` hooks. Those hooks run only
              when the input buffers contain
              ``"seqlen_agnostic_capture_inputs"``.

      Returns:
          On the last pipeline-parallel rank, ``output_buffers["hidden_states"]``.
          On other ranks, the whole ``output_buffers`` object (the
          intermediate tensors). Either way the object returned is the
          buffer recorded at capture time, not a copy, so a later replay
          writes into the same memory.
      """
      # KV caches are fixed tensors, so we don't need to copy them.
      del kv_caches

      # Copy the new inputs into the static buffers the graph was captured on.
      self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
      self.input_buffers["positions"].copy_(positions, non_blocking=True)
      self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                               non_blocking=True)
      if self.backend_name != "flashinfer":
          # These buffers exist only for non-flashinfer backends (see capture).
          self.input_buffers["seq_lens_tensor"].copy_(
              attn_metadata.decode_metadata.seq_lens_tensor,
              non_blocking=True)
          self.input_buffers["block_tables"].copy_(
              attn_metadata.decode_metadata.block_tables, non_blocking=True)
      if "seqlen_agnostic_capture_inputs" in self.input_buffers:
          self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
                                                    **kwargs)
      if intermediate_tensors is not None:
          for key in intermediate_tensors.tensors:
              self.input_buffers[key].copy_(intermediate_tensors[key],
                                            non_blocking=True)

      # Run the graph.
      self.graph.replay()

      if "seqlen_agnostic_capture_inputs" in self.input_buffers:
          self.model.copy_outputs_after_cuda_graphs(self.input_buffers,
                                                    **kwargs)

      # Return the output buffer(s) the graph wrote into.
      if get_pp_group().is_last_rank:
          return self.output_buffers["hidden_states"]
      return self.output_buffers
  def __call__(self, *args, **kwargs):
      """Make the runner callable; all arguments go unchanged to ``forward``."""
      replay = self.forward
      return replay(*args, **kwargs)
  def _get_graph_batch_size(batch_size: int) -> int:
      """Pad an actual batch size up to the graph batch size that covers it.

      The padded sizes are 1, 2, 4, then the multiples of
      ``_BATCH_SIZE_ALIGNMENT`` (``_BATCH_SIZE_ALIGNMENT``,
      ``2 * _BATCH_SIZE_ALIGNMENT``, ...). A ``batch_size`` of 2 or less is
      returned unchanged.
      """
      if batch_size > 4:
          alignment = _BATCH_SIZE_ALIGNMENT
          # Ceiling division, then scale back up to the next multiple.
          return -(-batch_size // alignment) * alignment
      return batch_size if batch_size <= 2 else 4
  def _is_block_tables_empty(block_tables: Union[None, Dict]):
      """Report whether ``block_tables`` holds no block tables.

      Returns True for ``None`` and for a dict whose values are all ``None``
      (an empty dict therefore counts as empty). Any other input, including
      a non-dict object, returns False.
      """
      if block_tables is None:
          return True
      if not isinstance(block_tables, dict):
          return False
      return all(table is None for table in block_tables.values())