model_runner.py 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427
  1. import dataclasses
  2. import gc
  3. import time
  4. import warnings
  5. import weakref
  6. from collections import defaultdict
  7. from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
  8. Tuple, Type, TypeVar, Union)
  9. import numpy as np
  10. import torch
  11. import torch.distributed
  12. import torch.nn as nn
  13. from loguru import logger
  14. try:
  15. from flashinfer import BatchDecodeWithPagedKVCacheWrapper
  16. from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
  17. from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
  18. FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
  19. except ImportError:
  20. BatchDecodeWithPagedKVCacheWrapper = None
  21. CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
  22. BatchPrefillWithPagedKVCacheWrapper = None
  23. FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
  24. from aphrodite.attention import AttentionMetadata, get_attn_backend
  25. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  26. LoRAConfig, ModelConfig, MultiModalConfig,
  27. ParallelConfig, PromptAdapterConfig,
  28. SchedulerConfig)
  29. from aphrodite.common.sampling_params import SamplingParams
  30. from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
  31. SequenceGroupMetadata)
  32. from aphrodite.common.utils import (CudaMemoryProfiler,
  33. get_kv_cache_torch_dtype, is_hip,
  34. is_pin_memory_available)
  35. from aphrodite.distributed import get_pp_group
  36. from aphrodite.distributed.parallel_state import (
  37. get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
  38. graph_capture)
  39. from aphrodite.inputs import INPUT_REGISTRY
  40. from aphrodite.lora.layers import LoRAMapping
  41. from aphrodite.lora.request import LoRARequest
  42. from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
  43. from aphrodite.modeling import SamplingMetadata
  44. from aphrodite.modeling.model_loader import get_model
  45. from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
  46. from aphrodite.modeling.models.interfaces import supports_lora, supports_vision
  47. from aphrodite.modeling.models.utils import set_cpu_offload_max_bytes
  48. from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
  49. MultiModalInputs)
  50. from aphrodite.prompt_adapter.layers import PromptAdapterMapping
  51. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  52. from aphrodite.prompt_adapter.worker_manager import \
  53. LRUCacheWorkerPromptAdapterManager
  54. from aphrodite.task_handler.model_runner_base import (
  55. ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
  56. _add_attn_metadata_broadcastable_dict,
  57. _add_sampling_metadata_broadcastable_dict,
  58. _init_attn_metadata_from_tensor_dict,
  59. _init_sampling_metadata_from_tensor_dict)
  60. if TYPE_CHECKING:
  61. from aphrodite.attention.backends.abstract import AttentionBackend
  62. _PAD_SLOT_ID = -1
  63. LORA_WARMUP_RANK = 8
  64. _BATCH_SIZE_ALIGNMENT = 8
  65. # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
  66. # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
  67. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
  68. _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
  69. ]
  70. _NUM_WARMUP_ITERS = 2
  71. TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
  72. @dataclasses.dataclass(frozen=True)
  73. class ModelInputForGPU(ModelRunnerInputBase):
  74. """
  75. This base class contains metadata needed for the base model forward pass
  76. but not metadata for possible additional steps, e.g., sampling. Model
  77. runners that run additional steps should subclass this method to add
  78. additional fields.
  79. """
  80. input_tokens: Optional[torch.Tensor] = None
  81. input_positions: Optional[torch.Tensor] = None
  82. seq_lens: Optional[List[int]] = None
  83. query_lens: Optional[List[int]] = None
  84. lora_mapping: Optional["LoRAMapping"] = None
  85. lora_requests: Optional[Set[LoRARequest]] = None
  86. attn_metadata: Optional["AttentionMetadata"] = None
  87. prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
  88. prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
  89. multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
  90. request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
  91. finished_requests_ids: Optional[List[str]] = None
  92. virtual_engine: int = 0
  93. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  94. tensor_dict = {
  95. "input_tokens": self.input_tokens,
  96. "input_positions": self.input_positions,
  97. "lora_requests": self.lora_requests,
  98. "lora_mapping": self.lora_mapping,
  99. "multi_modal_kwargs": self.multi_modal_kwargs,
  100. "prompt_adapter_mapping": self.prompt_adapter_mapping,
  101. "prompt_adapter_requests": self.prompt_adapter_requests,
  102. "virtual_engine": self.virtual_engine,
  103. "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
  104. "finished_requests_ids": self.finished_requests_ids,
  105. }
  106. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  107. return tensor_dict
  108. @classmethod
  109. def from_broadcasted_tensor_dict(
  110. cls: Type[TModelInputForGPU],
  111. tensor_dict: Dict[str, Any],
  112. attn_backend: Optional["AttentionBackend"] = None,
  113. ) -> TModelInputForGPU:
  114. if attn_backend is not None:
  115. tensor_dict = _init_attn_metadata_from_tensor_dict(
  116. attn_backend, tensor_dict)
  117. return cls(**tensor_dict)
  118. @dataclasses.dataclass(frozen=True)
  119. class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
  120. """
  121. Used by the ModelRunner.
  122. """
  123. sampling_metadata: Optional["SamplingMetadata"] = None
  124. # Used for speculative decoding. We do not broadcast it because it is only
  125. # used by the driver worker.
  126. is_prompt: Optional[bool] = None
  127. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  128. tensor_dict = {
  129. "input_tokens": self.input_tokens,
  130. "input_positions": self.input_positions,
  131. "lora_requests": self.lora_requests,
  132. "lora_mapping": self.lora_mapping,
  133. "multi_modal_kwargs": self.multi_modal_kwargs,
  134. "prompt_adapter_mapping": self.prompt_adapter_mapping,
  135. "prompt_adapter_requests": self.prompt_adapter_requests,
  136. "virtual_engine": self.virtual_engine,
  137. "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
  138. "finished_requests_ids": self.finished_requests_ids,
  139. }
  140. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  141. _add_sampling_metadata_broadcastable_dict(tensor_dict,
  142. self.sampling_metadata)
  143. return tensor_dict
  144. @classmethod
  145. def from_broadcasted_tensor_dict(
  146. cls,
  147. tensor_dict: Dict[str, Any],
  148. attn_backend: Optional["AttentionBackend"] = None,
  149. ) -> "ModelInputForGPUWithSamplingMetadata":
  150. tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
  151. if attn_backend is not None:
  152. tensor_dict = _init_attn_metadata_from_tensor_dict(
  153. attn_backend, tensor_dict)
  154. return cls(**tensor_dict)
  155. class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
  156. """TBA"""
  157. def __init__(self,
  158. runner: "GPUModelRunnerBase",
  159. finished_requests_ids: Optional[List[str]] = None):
  160. super().__init__()
  161. self.runner = runner
  162. self.model_input_cls = self.runner._model_input_cls
  163. self.attn_backend = self.runner.attn_backend
  164. self.scheduler_config = self.runner.scheduler_config
  165. self.sliding_window = self.runner.sliding_window
  166. self.block_size = self.runner.block_size
  167. self.enable_lora = self.runner.lora_config is not None
  168. self.enable_prompt_adapter = (self.runner.prompt_adapter_config
  169. is not None)
  170. self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
  171. self.finished_requests_ids = finished_requests_ids
  172. self.decode_only = True
  173. # Common inputs.
  174. self.input_tokens: List[int] = []
  175. self.input_positions: List[int] = []
  176. self.seq_lens: List[int] = []
  177. self.query_lens: List[int] = []
  178. self.max_decode_seq_len: int = 0
  179. self.request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list)
  180. # LoRA inputs.
  181. self.lora_index_mapping: List[int] = []
  182. self.lora_prompt_mapping: List[int] = []
  183. self.lora_requests: Set[LoRARequest] = set()
  184. # Prompt adapter inputs.
  185. self.prompt_adapter_index_mapping: List[int] = []
  186. self.prompt_adapter_prompt_mapping: List[int] = []
  187. self.prompt_adapter_requests: Set[PromptAdapterRequest] = set()
  188. # Multi-modal inputs.
  189. self.multi_modal_inputs_list: List[MultiModalInputs] = []
  190. # Attention metadata inputs.
  191. self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
  192. self)
  193. # Engine/Model configurations.
  194. self.chunked_prefill_enabled = (
  195. self.scheduler_config is not None
  196. and self.scheduler_config.chunked_prefill_enabled)
  197. if self.sliding_window is not None:
  198. self.sliding_window_blocks = (
  199. self.sliding_window + self.block_size - 1) // self.block_size
  200. self.block_aligned_sliding_window = \
  201. self.sliding_window_blocks * self.block_size
  202. def _compute_len_for_sliding_window(self, seq_len: int):
  203. curr_sliding_window_blocks = 0
  204. sliding_seq_len = seq_len
  205. # TODO: This is a hack to make sliding window work with
  206. # paged attn. We can remove it if we make paged attn kernel
  207. # to properly handle slinding window attn.
  208. if self.sliding_window is not None:
  209. curr_sliding_window_blocks = self.sliding_window_blocks
  210. if self.scheduler_config.use_v2_block_manager:
  211. # number of elements in last block
  212. suff_len = seq_len % self.block_size
  213. sliding_seq_len = min(
  214. seq_len, self.block_aligned_sliding_window + suff_len)
  215. if suff_len > 0:
  216. curr_sliding_window_blocks += 1
  217. else:
  218. sliding_seq_len = min(seq_len, self.sliding_window)
  219. return curr_sliding_window_blocks, sliding_seq_len
  220. def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
  221. seq_ids = list(seq_group_metadata.seq_data.keys())
  222. n_seqs = len(seq_ids)
  223. is_prompt = seq_group_metadata.is_prompt
  224. token_chunk_size = seq_group_metadata.token_chunk_size
  225. if is_prompt:
  226. assert n_seqs == 1
  227. self.decode_only = False
  228. # Mapping from request IDs to sequence IDs. Used for Jamba models
  229. # that manages the cache by itself.
  230. self.request_ids_to_seq_ids[seq_group_metadata.request_id] = []
  231. # The number of input tokens in each sequence.
  232. token_lens: List[int] = []
  233. # The number of tokens that are already computed.
  234. context_lens: List[int] = []
  235. # The current sliding window block for each sequence.
  236. curr_sliding_window_blocks: List[int] = []
  237. # The original sequence length (before applying sliding window)
  238. # for each sequence.
  239. orig_seq_lens: List[int] = []
  240. # The sequence length (may be capped to the sliding window).
  241. curr_seq_lens: List[int] = []
  242. for seq_id in seq_ids:
  243. seq_data = seq_group_metadata.seq_data[seq_id]
  244. self.request_ids_to_seq_ids[seq_group_metadata.request_id].append(
  245. seq_id)
  246. computed_block_nums = seq_group_metadata.computed_block_nums
  247. # Check if hit prefix cache (i.e., some blocks are already computed)
  248. # Note that prefix caching does not support sliding window.
  249. prefix_cache_hit = (computed_block_nums is not None
  250. and len(computed_block_nums) > 0
  251. and self.sliding_window is None and is_prompt)
  252. if self.chunked_prefill_enabled and prefix_cache_hit:
  253. raise RuntimeError(
  254. "chunked prefill cannot be used with prefix caching now.")
  255. # Compute context length (the number of tokens that are
  256. # already computed) and sequence length (total number of tokens).
  257. seq_len = seq_data.get_len()
  258. if is_prompt:
  259. context_len = seq_data.get_num_computed_tokens()
  260. else:
  261. # get_num_computed_tokens is incorrect for spec decoding.
  262. # So, we should have a special logic here.
  263. # TODO: Fix it.
  264. context_len = seq_len - 1
  265. seq_len = min(seq_len, context_len + token_chunk_size)
  266. # Compute tokens.
  267. if is_prompt:
  268. tokens = seq_data.get_token_ids()[context_len:seq_len]
  269. else:
  270. # Optimization. get_token_ids requires the entire copy of
  271. # tokens.
  272. tokens = [seq_data.get_last_token_id()]
  273. if prefix_cache_hit:
  274. assert computed_block_nums is not None
  275. context_len = len(computed_block_nums) * self.block_size
  276. tokens = tokens[context_len:]
  277. # These are seq_len/context_len capped to the sliding window.
  278. # They are passed to decode kernel.
  279. # We still need original seq_len/context_len to compute slot
  280. # mapping (and input position) below.
  281. if is_prompt:
  282. curr_sliding_window_block = 0
  283. sliding_seq_len = seq_len
  284. query_len = seq_len - context_len
  285. else:
  286. curr_sliding_window_block, sliding_seq_len = (
  287. self._compute_len_for_sliding_window(seq_len))
  288. query_len = 1
  289. self.seq_lens.append(sliding_seq_len)
  290. if not is_prompt:
  291. self.max_decode_seq_len = max(self.max_decode_seq_len,
  292. sliding_seq_len)
  293. self.query_lens.append(query_len)
  294. self.input_tokens.extend(tokens)
  295. self.input_positions.extend(list(range(context_len, seq_len)))
  296. # Intermediate data of the current sequence group for
  297. # the attention metadata.
  298. token_lens.append(len(tokens))
  299. context_lens.append(context_len)
  300. curr_seq_lens.append(sliding_seq_len)
  301. curr_sliding_window_blocks.append(curr_sliding_window_block)
  302. orig_seq_lens.append(seq_len)
  303. # Update attention metadata. Note that input builder attributes
  304. # (self.xxx) include all added sequences, so we need to slice
  305. # the last n_seqs sequences.
  306. self.attn_metadata_builder.add_seq_group(
  307. seq_group_metadata, token_lens, orig_seq_lens, curr_seq_lens,
  308. self.query_lens[-n_seqs:], context_lens,
  309. curr_sliding_window_blocks, prefix_cache_hit,
  310. self.chunked_prefill_enabled)
  311. # LoRA data.
  312. if self.enable_lora:
  313. lora_id = seq_group_metadata.lora_int_id
  314. for query_len in self.query_lens[-n_seqs:]:
  315. if lora_id > 0:
  316. self.lora_requests.add(seq_group_metadata.lora_request)
  317. self.lora_index_mapping += [lora_id] * query_len
  318. self.lora_prompt_mapping.extend(
  319. [lora_id] *
  320. (query_len if seq_group_metadata.sampling_params
  321. and seq_group_metadata.sampling_params.prompt_logprobs
  322. is not None else 1))
  323. # Prompt adapter data. Note that when is_prompt=True,
  324. # we expect only one sequence in the group.
  325. if self.enable_prompt_adapter:
  326. prompt_adapter_id = seq_group_metadata.prompt_adapter_id
  327. if prompt_adapter_id > 0 and is_prompt:
  328. query_len = self.query_lens[-1]
  329. self.prompt_adapter_requests.add(
  330. seq_group_metadata.prompt_adapter_request)
  331. num_tokens = seq_group_metadata.\
  332. prompt_adapter_num_virtual_tokens
  333. pm = [prompt_adapter_id
  334. ] * num_tokens + [0] * (query_len - num_tokens)
  335. self.prompt_adapter_index_mapping += pm
  336. self.prompt_adapter_prompt_mapping.extend(
  337. [prompt_adapter_id] *
  338. (query_len if seq_group_metadata.sampling_params
  339. and seq_group_metadata.sampling_params.prompt_logprobs
  340. else 1))
  341. # Multi-modal data.
  342. mm_data = seq_group_metadata.multi_modal_data
  343. if mm_data:
  344. mm_kwargs = self.multi_modal_input_mapper(mm_data)
  345. self.multi_modal_inputs_list.append(mm_kwargs)
  346. def build(self) -> ModelInputForGPU:
  347. if not self.input_tokens:
  348. return self.model_input_cls()
  349. batch_size = len(self.input_tokens)
  350. use_captured_graph = (
  351. self.decode_only and not self.runner.model_config.enforce_eager
  352. and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
  353. and self.max_decode_seq_len <= self.runner.max_seq_len_to_capture)
  354. # If cuda graph can be used, pad tensors accordingly.
  355. # See `capture_model` API for more details.
  356. # Aphrodite uses cuda graph only for decoding requests.
  357. cuda_graph_pad_size = -1
  358. if use_captured_graph:
  359. graph_batch_size = _get_graph_batch_size(batch_size)
  360. assert graph_batch_size >= batch_size
  361. cuda_graph_pad_size = graph_batch_size - batch_size
  362. batch_size = graph_batch_size
  363. # Tokens and positions.
  364. self.input_tokens.extend([0] * cuda_graph_pad_size)
  365. self.input_positions.extend([0] * cuda_graph_pad_size)
  366. input_tokens_tensor = torch.tensor(self.input_tokens,
  367. dtype=torch.long,
  368. device=self.runner.device)
  369. input_positions_tensor = torch.tensor(self.input_positions,
  370. dtype=torch.long,
  371. device=self.runner.device)
  372. # Sequence and query lengths.
  373. self.seq_lens.extend([1] * cuda_graph_pad_size)
  374. # Attention metadata.
  375. attn_metadata = self.attn_metadata_builder.build(
  376. self.runner, self.seq_lens, self.query_lens, cuda_graph_pad_size,
  377. batch_size)
  378. # LoRA data.
  379. if self.enable_lora:
  380. self.lora_index_mapping.extend([0] * cuda_graph_pad_size)
  381. lora_mapping = LoRAMapping(
  382. self.lora_index_mapping,
  383. self.lora_prompt_mapping,
  384. )
  385. else:
  386. lora_mapping = None
  387. # Prompt adapter data.
  388. if self.enable_prompt_adapter:
  389. self.prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size)
  390. prompt_adapter_mapping = PromptAdapterMapping(
  391. self.prompt_adapter_index_mapping,
  392. self.prompt_adapter_prompt_mapping,
  393. )
  394. else:
  395. prompt_adapter_mapping = None
  396. # Multi-modal data.
  397. multi_modal_kwargs = MultiModalInputs.batch(
  398. self.multi_modal_inputs_list, device=self.runner.device)
  399. return self.model_input_cls(
  400. input_tokens=input_tokens_tensor,
  401. input_positions=input_positions_tensor,
  402. attn_metadata=attn_metadata,
  403. seq_lens=self.seq_lens,
  404. query_lens=self.query_lens,
  405. lora_mapping=lora_mapping,
  406. lora_requests=self.lora_requests,
  407. multi_modal_kwargs=multi_modal_kwargs,
  408. request_ids_to_seq_ids=self.request_ids_to_seq_ids,
  409. finished_requests_ids=self.finished_requests_ids,
  410. prompt_adapter_mapping=prompt_adapter_mapping,
  411. prompt_adapter_requests=self.prompt_adapter_requests)
  412. class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
  413. """
  414. Helper class for shared methods between GPU model runners.
  415. """
  416. _model_input_cls: Type[TModelInputForGPU]
  417. def __init__(
  418. self,
  419. model_config: ModelConfig,
  420. parallel_config: ParallelConfig,
  421. scheduler_config: SchedulerConfig,
  422. device_config: DeviceConfig,
  423. cache_config: CacheConfig,
  424. load_config: LoadConfig,
  425. lora_config: Optional[LoRAConfig],
  426. kv_cache_dtype: Optional[str] = "auto",
  427. is_driver_worker: bool = False,
  428. prompt_adapter_config: Optional[PromptAdapterConfig] = None,
  429. multimodal_config: Optional[MultiModalConfig] = None,
  430. return_hidden_states: bool = False,
  431. tp_rank: int = 0,
  432. ):
  433. self.model_config = model_config
  434. self.parallel_config = parallel_config
  435. self.scheduler_config = scheduler_config
  436. self.device_config = device_config
  437. self.cache_config = cache_config
  438. self.lora_config = lora_config
  439. self.load_config = load_config
  440. self.is_driver_worker = is_driver_worker
  441. self.prompt_adapter_config = prompt_adapter_config
  442. self.multimodal_config = multimodal_config
  443. self.return_hidden_states = return_hidden_states
  444. self.device = self.device_config.device
  445. self.pin_memory = is_pin_memory_available()
  446. self.tp_rank = tp_rank
  447. self.kv_cache_dtype = kv_cache_dtype
  448. self.sliding_window = model_config.get_sliding_window()
  449. self.block_size = cache_config.block_size
  450. self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
  451. self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
  452. {} for _ in range(self.parallel_config.pipeline_parallel_size)
  453. ]
  454. self.graph_memory_pool: Optional[Tuple[
  455. int, int]] = None # Set during graph capture.
  456. self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers(
  457. parallel_config)
  458. # When using CUDA graph, the input block tables must be padded to
  459. # max_seq_len_to_capture. However, creating the block table in
  460. # Python can be expensive. To optimize this, we cache the block table
  461. # in numpy and only copy the actual input content at every iteration.
  462. # The shape of the cached block table will be
  463. # (max batch size to capture, max context len to capture / block size).
  464. self.graph_block_tables = np.zeros(
  465. (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
  466. dtype=np.int32)
  467. num_attn_heads = self.model_config.get_num_attention_heads(
  468. self.parallel_config, self.tp_rank)
  469. self.attn_backend = get_attn_backend(
  470. num_attn_heads,
  471. self.model_config.get_head_size(),
  472. self.model_config.get_num_kv_heads(self.parallel_config,
  473. self.tp_rank),
  474. self.model_config.get_sliding_window(),
  475. self.model_config.dtype,
  476. self.kv_cache_dtype,
  477. self.block_size,
  478. ) if num_attn_heads else None
  479. # Multi-modal data support
  480. self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
  481. .create_input_mapper(self.model_config)
  482. # Lazy initialization
  483. self.model: nn.Module # Set after load_model
  484. # Set after load_model.
  485. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
  486. self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None
  487. self.flashinfer_decode_workspace_buffer = None
  488. self.flashinfer_decode_wrapper = None
  489. self.flashinfer_prefill_workspace_buffer = None
  490. self.flashinfer_prefill_wrapper = None
  491. set_cpu_offload_max_bytes(
  492. int(self.cache_config.cpu_offload_gb * 1024**3))
  493. def load_model(self) -> None:
  494. with CudaMemoryProfiler() as m:
  495. # measure the time it takes to load the model
  496. start_time = time.time()
  497. self.model = get_model(model_config=self.model_config,
  498. device_config=self.device_config,
  499. load_config=self.load_config,
  500. lora_config=self.lora_config,
  501. multimodal_config=self.multimodal_config,
  502. parallel_config=self.parallel_config,
  503. scheduler_config=self.scheduler_config,
  504. cache_config=self.cache_config)
  505. end_time = time.time()
  506. self.model_memory_usage = m.consumed_memory
  507. tp = get_tensor_model_parallel_world_size()
  508. rank = get_tensor_model_parallel_rank()
  509. total_time = end_time - start_time
  510. if tp > 1:
  511. logger.info(
  512. f"Rank {rank}: Model weights loaded in {total_time:.2f} secs.")
  513. if rank == 0:
  514. logger.info(
  515. "Memory usage: "
  516. f"{self.model_memory_usage / float(2**30):.2f} GiB x {tp} ="
  517. f" {self.model_memory_usage * tp / float(2**30):.2f} GiB")
  518. else:
  519. logger.info(f"Model weights loaded in {total_time:.2f} seconds.")
  520. logger.info("Memory usage: "
  521. f"{self.model_memory_usage / float(2**30):.2f} GiB")
  522. if self.lora_config:
  523. assert supports_lora(self.model), "Model does not support LoRA"
  524. assert not supports_vision(
  525. self.model
  526. ), "To be tested: vision language model with LoRA settings."
  527. self.lora_manager = LRUCacheWorkerLoRAManager(
  528. self.scheduler_config.max_num_seqs,
  529. self.scheduler_config.max_num_batched_tokens,
  530. self.vocab_size,
  531. self.lora_config,
  532. self.device,
  533. self.model.embedding_modules,
  534. self.model.embedding_padding_modules,
  535. max_position_embeddings=self.model.config.
  536. max_position_embeddings,
  537. )
  538. self.model = self.lora_manager.create_lora_manager(self.model)
  539. if self.prompt_adapter_config:
  540. self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
  541. self.scheduler_config.max_num_seqs,
  542. self.scheduler_config.max_num_batched_tokens, self.device,
  543. self.prompt_adapter_config)
  544. self.model = (
  545. self.prompt_adapter_manager.create_prompt_adapter_manager(
  546. self.model))
  547. if self.kv_cache_dtype == "fp8" and is_hip():
  548. # Currently only ROCm accepts kv-cache scaling factors
  549. # via quantization_param_path and this will be deprecated
  550. # in the future.
  551. if self.model_config.quantization_param_path is not None:
  552. if callable(getattr(self.model, "load_kv_cache_scales", None)):
  553. warnings.warn(
  554. "Loading kv cache scaling factor from JSON is "
  555. "deprecated and will be removed. Please include "
  556. "kv cache scaling factors in the model checkpoint.",
  557. FutureWarning,
  558. stacklevel=2)
  559. self.model.load_kv_cache_scales(
  560. self.model_config.quantization_param_path)
  561. logger.info(
  562. "Loaded KV cache scaling factors from ",
  563. f"{self.model_config.quantization_param_path}")
  564. else:
  565. raise RuntimeError(
  566. "Using FP8 KV cache and scaling factors provided but "
  567. f"model {self.model.__class__} does not support loading"
  568. " scaling factors.", )
  569. else:
  570. logger.warning(
  571. "Using FP8 KV cache but no scaling factors "
  572. "provided. Defaulting to scaling factors of 1.0. "
  573. "This may lead to less accurate results!")
  574. def save_sharded_state(
  575. self,
  576. path: str,
  577. pattern: Optional[str] = None,
  578. max_size: Optional[int] = None,
  579. ) -> None:
  580. from aphrodite.modeling.model_loader.loader import ShardedStateLoader
  581. ShardedStateLoader.save_model(
  582. self.model,
  583. path,
  584. pattern=pattern,
  585. max_size=max_size,
  586. )
  587. def save_tensorized_model(
  588. self,
  589. tensorizer_config: TensorizerConfig,
  590. ) -> None:
  591. from aphrodite.modeling.model_loader.loader import TensorizerLoader
  592. TensorizerLoader.save_model(
  593. self.model,
  594. tensorizer_config=tensorizer_config,
  595. )
  596. def get_max_block_per_batch(self) -> int:
  597. block_size = self.block_size
  598. return (self.max_seq_len_to_capture + block_size - 1) // block_size
  599. def _prepare_model_input_tensors(
  600. self,
  601. seq_group_metadata_list: List[SequenceGroupMetadata],
  602. finished_requests_ids: Optional[List[str]] = None
  603. ) -> TModelInputForGPU:
  604. """Helper method to prepare the model input based on a given sequence
  605. group. Prepares metadata needed for the base model forward pass but not
  606. metadata for possible additional steps, e.g., sampling.
  607. The API assumes seq_group_metadata_list is sorted by prefill -> decode.
  608. The result tensors and data structure also batches input in prefill
  609. -> decode order. For example,
  610. - input_tokens[:num_prefill_tokens] contains prefill tokens.
  611. - input_tokens[num_prefill_tokens:] contains decode tokens.
  612. If cuda graph is required, this API automatically pads inputs.
  613. """
  614. builder = ModelInputForGPUBuilder(weakref.proxy(self),
  615. finished_requests_ids)
  616. for seq_group_metadata in seq_group_metadata_list:
  617. builder.add_seq_group(seq_group_metadata)
  618. return builder.build() # type: ignore
  619. @torch.inference_mode()
  620. def profile_run(self) -> None:
  621. # Enable top-k sampling to reflect the accurate memory usage.
  622. sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
  623. max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
  624. max_num_seqs = self.scheduler_config.max_num_seqs
  625. # This represents the maximum number of different requests
  626. # that will have unique loras, an therefore the max amount of memory
  627. # consumption create dummy lora request copies from the lora request
  628. # passed in, which contains a lora from the lora warmup path.
  629. dummy_lora_requests: List[LoRARequest] = []
  630. dummy_lora_requests_per_seq: List[LoRARequest] = []
  631. if self.lora_config:
  632. assert self.lora_manager is not None
  633. with self.lora_manager.dummy_lora_cache():
  634. for idx in range(self.lora_config.max_loras):
  635. lora_id = idx + 1
  636. dummy_lora_request = LoRARequest(
  637. lora_name=f"warmup_{lora_id}",
  638. lora_int_id=lora_id,
  639. lora_local_path="/not/a/real/path",
  640. )
  641. self.lora_manager.add_dummy_lora(dummy_lora_request,
  642. rank=LORA_WARMUP_RANK)
  643. dummy_lora_requests.append(dummy_lora_request)
  644. dummy_lora_requests_per_seq = [
  645. dummy_lora_requests[idx % len(dummy_lora_requests)]
  646. for idx in range(max_num_seqs)
  647. ]
  648. # Profile memory usage with max_num_sequences sequences and the total
  649. # number of tokens equal to max_num_batched_tokens.
  650. seqs: List[SequenceGroupMetadata] = []
  651. # Additional GPU memory may be needed for vision encoding, which needs
  652. # to be accounted for when calculating the GPU blocks for
  653. # Aphrodite blocker manager.
  654. # To exercise the worst scenario for GPU memory consumption,
  655. # the number of seqs (batch_size) is chosen to maximize the number
  656. # of images processed.
  657. model_config = self.model_config
  658. if supports_vision(self.model):
  659. max_mm_tokens = MULTIMODAL_REGISTRY \
  660. .get_max_multimodal_tokens(model_config)
  661. max_num_seqs_orig = max_num_seqs
  662. max_num_seqs = min(max_num_seqs,
  663. max_num_batched_tokens // max_mm_tokens)
  664. if max_num_seqs < 1:
  665. expr = (f"min({max_num_seqs_orig}, "
  666. f"{max_num_batched_tokens} // {max_mm_tokens})")
  667. logger.warning(
  668. f"Computed max_num_seqs ({expr}) to be less than 1. "
  669. "Setting it to the minimum value of 1.")
  670. max_num_seqs = 1
  671. batch_size = 0
  672. for group_id in range(max_num_seqs):
  673. seq_len = (max_num_batched_tokens // max_num_seqs +
  674. (group_id < max_num_batched_tokens % max_num_seqs))
  675. batch_size += seq_len
  676. seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
  677. .dummy_data_for_profiling(model_config, seq_len)
  678. # Having more tokens is over-conservative but otherwise fine
  679. assert len(seq_data.prompt_token_ids) >= seq_len, (
  680. f"Expected at least {seq_len} dummy tokens for profiling, "
  681. f"but got: {len(seq_data.prompt_token_ids)}")
  682. seq = SequenceGroupMetadata(
  683. request_id=str(group_id),
  684. is_prompt=True,
  685. seq_data={group_id: seq_data},
  686. sampling_params=sampling_params,
  687. block_tables=None,
  688. lora_request=dummy_lora_requests_per_seq[group_id]
  689. if dummy_lora_requests_per_seq else None,
  690. multi_modal_data=dummy_multi_modal_data,
  691. )
  692. seqs.append(seq)
  693. # Run the model with the dummy inputs.
  694. num_layers = self.model_config.get_num_layers(self.parallel_config)
  695. kv_caches = [None] * num_layers
  696. finished_requests_ids = [seq.request_id for seq in seqs]
  697. model_input = self.prepare_model_input(
  698. seqs, finished_requests_ids=finished_requests_ids)
  699. intermediate_tensors = None
  700. if not get_pp_group().is_first_rank:
  701. intermediate_tensors = self.model.make_empty_intermediate_tensors(
  702. batch_size=batch_size,
  703. dtype=self.model_config.dtype,
  704. device=self.device)
  705. self.execute_model(model_input, kv_caches, intermediate_tensors)
  706. torch.cuda.synchronize()
  707. return
  708. def remove_all_loras(self):
  709. if not self.lora_manager:
  710. raise RuntimeError("LoRA is not enabled.")
  711. self.lora_manager.remove_all_adapters()
  712. def set_active_loras(self, lora_requests: Set[LoRARequest],
  713. lora_mapping: LoRAMapping) -> None:
  714. if not self.lora_manager:
  715. raise RuntimeError("LoRA is not enabled.")
  716. self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
  717. def add_lora(self, lora_request: LoRARequest) -> bool:
  718. if not self.lora_manager:
  719. raise RuntimeError("LoRA is not enabled.")
  720. return self.lora_manager.add_adapter(lora_request)
  721. def remove_lora(self, lora_id: int) -> bool:
  722. if not self.lora_manager:
  723. raise RuntimeError("LoRA is not enabled.")
  724. return self.lora_manager.remove_adapter(lora_id)
  725. def pin_lora(self, lora_id: int) -> bool:
  726. if not self.lora_manager:
  727. raise RuntimeError("LoRA is not enabled.")
  728. return self.lora_manager.pin_adapter(lora_id)
  729. def list_loras(self) -> Set[int]:
  730. if not self.lora_manager:
  731. raise RuntimeError("LoRA is not enabled.")
  732. return self.lora_manager.list_adapters()
  733. def remove_all_prompt_adapters(self):
  734. if not self.prompt_adapter_manager:
  735. raise RuntimeError("PromptAdapter is not enabled.")
  736. self.prompt_adapter_manager.remove_all_adapters()
  737. def set_active_prompt_adapters(
  738. self, prompt_adapter_requests: Set[PromptAdapterRequest],
  739. prompt_adapter_mapping: PromptAdapterMapping) -> None:
  740. if not self.prompt_adapter_manager:
  741. raise RuntimeError("PromptAdapter is not enabled.")
  742. self.prompt_adapter_manager.set_active_adapters(
  743. prompt_adapter_requests, prompt_adapter_mapping)
  744. def add_prompt_adapter(
  745. self, prompt_adapter_request: PromptAdapterRequest) -> bool:
  746. if not self.prompt_adapter_manager:
  747. raise RuntimeError("PromptAdapter is not enabled.")
  748. return self.prompt_adapter_manager.add_adapter(prompt_adapter_request)
  749. def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  750. if not self.prompt_adapter_manager:
  751. raise RuntimeError("PromptAdapter is not enabled.")
  752. return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id)
  753. def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  754. if not self.prompt_adapter_manager:
  755. raise RuntimeError("PromptAdapter is not enabled.")
  756. return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id)
  757. def list_prompt_adapters(self) -> Set[int]:
  758. if not self.prompt_adapter_manager:
  759. raise RuntimeError("PromptAdapter is not enabled.")
  760. return self.prompt_adapter_manager.list_adapters()
  761. @torch.inference_mode()
  762. def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
  763. """Cuda graph capture a model.
  764. Note that CUDA graph's performance gain is negligible if number
  765. of batched tokens are larger than 200. And since CUDA graph
  766. requires fixed sized tensors, supporting large/variable batch
  767. size requires high GPU memory overhead. Thus, Aphrodite only captures
  768. decoding requests. Mixed batch (chunked prefill + decoding) or
  769. prefill requests are not captured.
  770. Since it is used for decoding-only, it assumes there's only 1 token
  771. per sequence in the batch.
  772. """
  773. assert not self.model_config.enforce_eager
  774. logger.info("Capturing the model for CUDA graphs. This may lead to "
  775. "unexpected consequences if the model is not static. To "
  776. "run the model in eager mode, set 'enforce_eager=True' or "
  777. "use '--enforce-eager' in the CLI.")
  778. logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
  779. "If you are running out of memory, consider decreasing "
  780. "`gpu_memory_utilization` or enforcing eager mode. "
  781. "You can also reduce the `max_num_seqs` as needed "
  782. "to decrease memory usage.")
  783. start_time = time.perf_counter()
  784. # Prepare dummy inputs. These will be reused for all batch sizes.
  785. max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
  786. input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
  787. input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
  788. slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
  789. slot_mapping.fill_(_PAD_SLOT_ID)
  790. seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
  791. block_tables = torch.from_numpy(self.graph_block_tables).cuda()
  792. intermediate_inputs = None
  793. if not get_pp_group().is_first_rank:
  794. intermediate_inputs = self.model.make_empty_intermediate_tensors(
  795. batch_size=max_batch_size,
  796. dtype=self.model_config.dtype,
  797. device=self.device)
  798. # Prepare buffer for outputs. These will be reused for all batch sizes.
  799. # It will be filled after the first graph capture.
  800. hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [
  801. None
  802. ] * self.parallel_config.pipeline_parallel_size
  803. graph_batch_size = _get_graph_batch_size(
  804. self.scheduler_config.max_num_seqs)
  805. batch_size_capture_list = [
  806. bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
  807. ]
  808. if self.attn_backend.get_name() == "flashinfer":
  809. # For flashinfer, different batch sizes will share the
  810. # same workspace buffer.
  811. decode_workspace_buffer = \
  812. torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE,
  813. dtype=torch.uint8,
  814. device=self.device)
  815. indices_buffer = torch.empty(max_batch_size *
  816. self.cache_config.num_gpu_blocks,
  817. dtype=torch.int32,
  818. device=self.device)
  819. indptr_buffer = torch.empty(max_batch_size + 1,
  820. dtype=torch.int32,
  821. device=self.device)
  822. last_page_len_buffer = torch.empty(max_batch_size,
  823. dtype=torch.int32,
  824. device=self.device)
  825. with graph_capture() as graph_capture_context:
  826. # NOTE: Capturing the largest batch size first may help reduce the
  827. # memory usage of CUDA graph.
  828. for virtual_engine in range(
  829. self.parallel_config.pipeline_parallel_size):
  830. for batch_size in reversed(batch_size_capture_list):
  831. if self.attn_backend.get_name() == "flashinfer":
  832. indptr_buffer = indptr_buffer[:batch_size + 1]
  833. last_page_len_buffer = last_page_len_buffer[:
  834. batch_size]
  835. num_qo_heads = (
  836. self.model_config.get_num_attention_heads(
  837. self.parallel_config), self.tp_rank)
  838. num_kv_heads = self.model_config.get_num_kv_heads(
  839. self.parallel_config, self.tp_rank)
  840. if num_qo_heads // num_kv_heads >= 4:
  841. use_tensor_cores = True
  842. else:
  843. use_tensor_cores = False
  844. decode_wrapper = \
  845. CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
  846. decode_workspace_buffer, indptr_buffer,
  847. indices_buffer, last_page_len_buffer, "NHD",
  848. use_tensor_cores)
  849. kv_cache_dtype = get_kv_cache_torch_dtype(
  850. self.kv_cache_dtype, self.model_config.dtype)
  851. paged_kv_indptr_tensor_host = torch.arange(
  852. 0, batch_size + 1, dtype=torch.int32)
  853. paged_kv_indices_tensor_host = torch.arange(
  854. 0, batch_size, dtype=torch.int32)
  855. paged_kv_last_page_len_tensor_host = torch.full(
  856. (batch_size, ), self.block_size, dtype=torch.int32)
  857. query_start_loc_host = torch.arange(0,
  858. batch_size + 1,
  859. dtype=torch.int32)
  860. attn_metadata = self.attn_backend.make_metadata(
  861. num_prefills=0,
  862. slot_mapping=slot_mapping[:batch_size],
  863. num_prefill_tokens=0,
  864. num_decode_tokens=batch_size,
  865. max_prefill_seq_len=0,
  866. block_tables=block_tables,
  867. paged_kv_indptr=paged_kv_indptr_tensor_host,
  868. paged_kv_indices=paged_kv_indices_tensor_host,
  869. paged_kv_last_page_len=
  870. paged_kv_last_page_len_tensor_host,
  871. num_qo_heads=num_qo_heads,
  872. num_kv_heads=num_kv_heads,
  873. head_dim=self.model_config.get_head_size(),
  874. page_size=self.block_size,
  875. seq_start_loc=None,
  876. query_start_loc=query_start_loc_host,
  877. device=self.device,
  878. data_type=kv_cache_dtype,
  879. use_cuda_graph=True,
  880. decode_wrapper=decode_wrapper,
  881. prefill_wrapper=None)
  882. attn_metadata.begin_forward()
  883. else:
  884. attn_metadata = self.attn_backend.make_metadata(
  885. num_prefills=0,
  886. num_prefill_tokens=0,
  887. num_decode_tokens=batch_size,
  888. slot_mapping=slot_mapping[:batch_size],
  889. seq_lens=None,
  890. seq_lens_tensor=seq_lens[:batch_size],
  891. max_query_len=None,
  892. max_prefill_seq_len=0,
  893. max_decode_seq_len=self.max_seq_len_to_capture,
  894. query_start_loc=None,
  895. seq_start_loc=None,
  896. context_lens_tensor=None,
  897. block_tables=block_tables[:batch_size],
  898. use_cuda_graph=True,
  899. )
  900. if self.lora_config:
  901. lora_mapping = LoRAMapping(
  902. [0] * batch_size,
  903. [0] * batch_size,
  904. )
  905. self.set_active_loras(set(), lora_mapping)
  906. if self.prompt_adapter_config:
  907. prompt_adapter_mapping = PromptAdapterMapping(
  908. [-1] * batch_size,
  909. [-1] * batch_size,
  910. )
  911. self.set_active_prompt_adapters(
  912. set(), prompt_adapter_mapping)
  913. graph_runner = CUDAGraphRunner(
  914. self.model, self.attn_backend.get_name())
  915. if self.attn_backend.get_name() == "flashinfer":
  916. graph_runner.flashinfer_indptr_buffer = indptr_buffer
  917. graph_runner.flashinfer_indices_buffer = indices_buffer
  918. graph_runner.flashinfer_last_page_len_buffer = \
  919. last_page_len_buffer
  920. graph_runner.flashinfer_decode_workspace_buffer = \
  921. decode_workspace_buffer
  922. graph_runner.flashinfer_decode_wrapper = \
  923. decode_wrapper
  924. capture_inputs = {
  925. "input_ids":
  926. input_tokens[:batch_size],
  927. "positions":
  928. input_positions[:batch_size],
  929. "hidden_or_intermediate_states":
  930. hidden_or_intermediate_states[
  931. virtual_engine] # type: ignore
  932. [:batch_size]
  933. if hidden_or_intermediate_states[virtual_engine]
  934. is not None else None,
  935. "intermediate_inputs":
  936. intermediate_inputs[:batch_size]
  937. if intermediate_inputs is not None else None,
  938. "kv_caches":
  939. kv_caches[virtual_engine],
  940. "attn_metadata":
  941. attn_metadata,
  942. "memory_pool":
  943. self.graph_memory_pool,
  944. "stream":
  945. graph_capture_context.stream
  946. }
  947. if self.has_seqlen_agnostic:
  948. # Only used by Mamba-based models CUDA graph atm (Jamba)
  949. capture_inputs.update({
  950. "seqlen_agnostic_capture_inputs":
  951. self.model.get_seqlen_agnostic_capture_inputs(
  952. batch_size)
  953. })
  954. graph_runner.capture(**capture_inputs)
  955. self.graph_memory_pool = graph_runner.graph.pool()
  956. self.graph_runners[virtual_engine][batch_size] = (
  957. graph_runner)
  958. end_time = time.perf_counter()
  959. elapsed_time = end_time - start_time
  960. # This usually takes < 10 seconds.
  961. logger.info(f"Graph capturing finished in {elapsed_time:2f} secs.")
  962. @property
  963. def vocab_size(self) -> int:
  964. return self.model_config.get_vocab_size()
  965. class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
  966. """
  967. GPU model runner with sampling step.
  968. """
  969. _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
  970. ModelInputForGPUWithSamplingMetadata)
  971. def make_model_input_from_broadcasted_tensor_dict(
  972. self,
  973. tensor_dict: Dict[str, Any],
  974. ) -> ModelInputForGPUWithSamplingMetadata:
  975. model_input = \
  976. ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
  977. tensor_dict,
  978. attn_backend=self.attn_backend,
  979. )
  980. return model_input
  981. def prepare_model_input(
  982. self,
  983. seq_group_metadata_list: List[SequenceGroupMetadata],
  984. virtual_engine: int = 0,
  985. finished_requests_ids: Optional[List[str]] = None
  986. ) -> ModelInputForGPUWithSamplingMetadata:
  987. """Prepare the model input based on a given sequence group, including
  988. metadata for the sampling step.
  989. The API assumes seq_group_metadata_list is sorted by prefill -> decode.
  990. The result tensors and data structure also batches input in prefill
  991. -> decode order. For example,
  992. - input_tokens[:num_prefill_tokens] contains prefill tokens.
  993. - input_tokens[num_prefill_tokens:] contains decode tokens.
  994. If cuda graph is required, this API automatically pads inputs.
  995. """
  996. model_input = self._prepare_model_input_tensors(
  997. seq_group_metadata_list, finished_requests_ids)
  998. sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
  999. model_input.seq_lens,
  1000. model_input.query_lens,
  1001. self.device,
  1002. self.pin_memory)
  1003. is_prompt = (seq_group_metadata_list[0].is_prompt
  1004. if seq_group_metadata_list else None)
  1005. return dataclasses.replace(model_input,
  1006. sampling_metadata=sampling_metadata,
  1007. is_prompt=is_prompt,
  1008. virtual_engine=virtual_engine)
  1009. @torch.inference_mode()
  1010. def execute_model(
  1011. self,
  1012. model_input: ModelInputForGPUWithSamplingMetadata,
  1013. kv_caches: List[torch.Tensor],
  1014. intermediate_tensors: Optional[IntermediateTensors] = None,
  1015. num_steps: int = 1,
  1016. ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
  1017. if num_steps > 1:
  1018. raise ValueError("num_steps > 1 is not supported in ModelRunner")
  1019. if self.lora_config:
  1020. assert model_input.lora_requests is not None
  1021. assert model_input.lora_mapping is not None
  1022. self.set_active_loras(model_input.lora_requests,
  1023. model_input.lora_mapping)
  1024. if self.prompt_adapter_config:
  1025. assert model_input.prompt_adapter_requests is not None
  1026. assert model_input.prompt_adapter_mapping is not None
  1027. self.set_active_prompt_adapters(
  1028. model_input.prompt_adapter_requests,
  1029. model_input.prompt_adapter_mapping)
  1030. if self.attn_backend.get_name() == "flashinfer":
  1031. assert model_input.attn_metadata is not None
  1032. assert model_input.input_tokens is not None
  1033. if self.flashinfer_decode_workspace_buffer is None:
  1034. self.flashinfer_decode_workspace_buffer = torch.empty(
  1035. FLASHINFER_WORKSPACE_BUFFER_SIZE,
  1036. dtype=torch.uint8,
  1037. device=self.device)
  1038. self.flashinfer_decode_wrapper = \
  1039. BatchDecodeWithPagedKVCacheWrapper(
  1040. self.flashinfer_decode_workspace_buffer, "NHD")
  1041. self.flashinfer_prefill_workspace_buffer = torch.empty(
  1042. FLASHINFER_WORKSPACE_BUFFER_SIZE,
  1043. dtype=torch.uint8,
  1044. device=self.device)
  1045. self.flashinfer_prefill_wrapper = \
  1046. BatchPrefillWithPagedKVCacheWrapper(
  1047. self.flashinfer_prefill_workspace_buffer, "NHD")
  1048. model_input.attn_metadata.prefill_wrapper = \
  1049. self.flashinfer_prefill_wrapper
  1050. if model_input.attn_metadata.use_cuda_graph:
  1051. batch_size = model_input.input_tokens.shape[0]
  1052. model_input.attn_metadata.decode_wrapper = self.graph_runners[
  1053. model_input.
  1054. virtual_engine][batch_size].flashinfer_decode_wrapper
  1055. else:
  1056. model_input.attn_metadata.decode_wrapper = \
  1057. self.flashinfer_decode_wrapper
  1058. model_input.attn_metadata.begin_forward()
  1059. # Currently cuda graph is only supported by the decode phase.
  1060. assert model_input.attn_metadata is not None
  1061. prefill_meta = model_input.attn_metadata.prefill_metadata
  1062. decode_meta = model_input.attn_metadata.decode_metadata
  1063. # TODO: We can remove this once all
  1064. # virtual engines share the same kv cache.
  1065. virtual_engine = model_input.virtual_engine
  1066. if prefill_meta is None and decode_meta.use_cuda_graph:
  1067. assert model_input.input_tokens is not None
  1068. graph_batch_size = model_input.input_tokens.shape[0]
  1069. model_executable = self.graph_runners[virtual_engine][
  1070. graph_batch_size]
  1071. else:
  1072. model_executable = self.model
  1073. multi_modal_kwargs = model_input.multi_modal_kwargs or {}
  1074. seqlen_agnostic_kwargs = {
  1075. "finished_requests_ids": model_input.finished_requests_ids,
  1076. "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
  1077. } if self.has_seqlen_agnostic else {}
  1078. hidden_or_intermediate_states = model_executable(
  1079. input_ids=model_input.input_tokens,
  1080. positions=model_input.input_positions,
  1081. kv_caches=kv_caches,
  1082. attn_metadata=model_input.attn_metadata,
  1083. intermediate_tensors=intermediate_tensors,
  1084. **multi_modal_kwargs,
  1085. **seqlen_agnostic_kwargs,
  1086. )
  1087. # Compute the logits in the last pipeline stage.
  1088. if not get_pp_group().is_last_rank:
  1089. return hidden_or_intermediate_states
  1090. logits = self.model.compute_logits(hidden_or_intermediate_states,
  1091. model_input.sampling_metadata)
  1092. if not self.is_driver_worker:
  1093. return []
  1094. # Sample the next token.
  1095. output: SamplerOutput = self.model.sample(
  1096. logits=logits,
  1097. sampling_metadata=model_input.sampling_metadata,
  1098. )
  1099. if self.return_hidden_states:
  1100. # we only need to pass hidden states of most recent token
  1101. assert model_input.sampling_metadata is not None
  1102. indices = model_input.sampling_metadata.selected_token_indices
  1103. if model_input.is_prompt:
  1104. hidden_states = hidden_or_intermediate_states.index_select(
  1105. 0, indices)
  1106. elif decode_meta.use_cuda_graph:
  1107. hidden_states = hidden_or_intermediate_states[:len(indices)]
  1108. else:
  1109. hidden_states = hidden_or_intermediate_states
  1110. output.hidden_states = hidden_states
  1111. return [output]
  1112. class CUDAGraphRunner:
  1113. def __init__(self, model: nn.Module, backend_name: str):
  1114. self.model = model
  1115. self.backend_name = backend_name
  1116. self.input_buffers: Dict[str, torch.Tensor] = {}
  1117. self.output_buffers: Dict[str, torch.Tensor] = {}
  1118. self._graph: Optional[torch.cuda.CUDAGraph] = None
  1119. self.flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None
  1120. self.flashinfer_indptr_buffer: Optional[torch.Tensor] = None
  1121. self.flashinfer_indices_buffer: Optional[torch.Tensor] = None
  1122. self.flashinfer_last_page_len_buffer: Optional[torch.Tensor] = None
  1123. self.flashinfer_decode_wrapper: Optional[
  1124. CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None
  1125. @property
  1126. def graph(self):
  1127. assert self._graph is not None
  1128. return self._graph
  1129. def capture(
  1130. self,
  1131. input_ids: torch.Tensor,
  1132. positions: torch.Tensor,
  1133. hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
  1134. torch.Tensor]],
  1135. intermediate_inputs: Optional[IntermediateTensors],
  1136. kv_caches: List[torch.Tensor],
  1137. attn_metadata: AttentionMetadata,
  1138. memory_pool: Optional[Tuple[int, int]],
  1139. stream: torch.cuda.Stream,
  1140. **kwargs,
  1141. ) -> Union[torch.Tensor, IntermediateTensors]:
  1142. assert self._graph is None
  1143. # Run the model a few times without capturing the graph.
  1144. # This is to make sure that the captured graph does not include the
  1145. # kernel launches for initial benchmarking (e.g., Triton autotune).
  1146. # Note one iteration is not enough for torch.jit.script
  1147. for _ in range(_NUM_WARMUP_ITERS):
  1148. self.model(
  1149. input_ids,
  1150. positions,
  1151. kv_caches,
  1152. attn_metadata,
  1153. intermediate_inputs,
  1154. **kwargs,
  1155. )
  1156. torch.cuda.synchronize()
  1157. # Capture the graph.
  1158. self._graph = torch.cuda.CUDAGraph()
  1159. with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
  1160. output_hidden_or_intermediate_states = self.model(
  1161. input_ids,
  1162. positions,
  1163. kv_caches,
  1164. attn_metadata,
  1165. intermediate_inputs,
  1166. **kwargs,
  1167. )
  1168. if hidden_or_intermediate_states is not None:
  1169. if get_pp_group().is_last_rank:
  1170. hidden_or_intermediate_states.copy_(
  1171. output_hidden_or_intermediate_states)
  1172. else:
  1173. for key in hidden_or_intermediate_states.tensors:
  1174. hidden_or_intermediate_states[key].copy_(
  1175. output_hidden_or_intermediate_states[key])
  1176. else:
  1177. hidden_or_intermediate_states = (
  1178. output_hidden_or_intermediate_states)
  1179. del output_hidden_or_intermediate_states
  1180. # make sure `output_hidden_states` is deleted
  1181. # in the graph's memory pool
  1182. gc.collect()
  1183. torch.cuda.synchronize()
  1184. # Save the input and output buffers.
  1185. if self.backend_name == "flashinfer":
  1186. self.input_buffers = {
  1187. "input_ids": input_ids,
  1188. "positions": positions,
  1189. "kv_caches": kv_caches,
  1190. "slot_mapping": attn_metadata.slot_mapping,
  1191. **kwargs,
  1192. }
  1193. else:
  1194. self.input_buffers = {
  1195. "input_ids": input_ids,
  1196. "positions": positions,
  1197. "kv_caches": kv_caches,
  1198. "slot_mapping": attn_metadata.slot_mapping,
  1199. "seq_lens_tensor":
  1200. attn_metadata.decode_metadata.seq_lens_tensor,
  1201. "block_tables": attn_metadata.decode_metadata.block_tables,
  1202. **kwargs,
  1203. }
  1204. if intermediate_inputs is not None:
  1205. self.input_buffers.update(intermediate_inputs.tensors)
  1206. if get_pp_group().is_last_rank:
  1207. self.output_buffers = {
  1208. "hidden_states": hidden_or_intermediate_states
  1209. }
  1210. else:
  1211. self.output_buffers = hidden_or_intermediate_states
  1212. return hidden_or_intermediate_states
  1213. def forward(
  1214. self,
  1215. input_ids: torch.Tensor,
  1216. positions: torch.Tensor,
  1217. kv_caches: List[torch.Tensor],
  1218. attn_metadata: AttentionMetadata,
  1219. intermediate_tensors: Optional[IntermediateTensors],
  1220. **kwargs,
  1221. ) -> torch.Tensor:
  1222. # KV caches are fixed tensors, so we don't need to copy them.
  1223. del kv_caches
  1224. # Copy the input tensors to the input buffers.
  1225. self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
  1226. self.input_buffers["positions"].copy_(positions, non_blocking=True)
  1227. self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
  1228. non_blocking=True)
  1229. if self.backend_name != "flashinfer":
  1230. self.input_buffers["seq_lens_tensor"].copy_(
  1231. attn_metadata.decode_metadata.seq_lens_tensor,
  1232. non_blocking=True)
  1233. self.input_buffers["block_tables"].copy_(
  1234. attn_metadata.decode_metadata.block_tables, non_blocking=True)
  1235. if "seqlen_agnostic_capture_inputs" in self.input_buffers:
  1236. self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
  1237. **kwargs)
  1238. if intermediate_tensors is not None:
  1239. for key in intermediate_tensors.tensors:
  1240. self.input_buffers[key].copy_(intermediate_tensors[key],
  1241. non_blocking=True)
  1242. # Run the graph.
  1243. self.graph.replay()
  1244. if "seqlen_agnostic_capture_inputs" in self.input_buffers:
  1245. self.model.copy_outputs_after_cuda_graphs(self.input_buffers,
  1246. **kwargs)
  1247. # Return the output tensor.
  1248. if get_pp_group().is_last_rank:
  1249. return self.output_buffers["hidden_states"]
  1250. return self.output_buffers
  1251. def __call__(self, *args, **kwargs):
  1252. return self.forward(*args, **kwargs)
  1253. def _get_graph_batch_size(batch_size: int) -> int:
  1254. """Returns the padded batch size given actual batch size.
  1255. Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
  1256. 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
  1257. """
  1258. if batch_size <= 2:
  1259. return batch_size
  1260. elif batch_size <= 4:
  1261. return 4
  1262. else:
  1263. return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
  1264. _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
  1265. def _is_block_tables_empty(block_tables: Union[None, Dict]):
  1266. """
  1267. Check if block_tables is None or a dictionary with all None values.
  1268. """
  1269. if block_tables is None:
  1270. return True
  1271. if isinstance(block_tables, dict) and all(
  1272. value is None for value in block_tables.values()):
  1273. return True
  1274. return False