model_runner.py 71 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610
  1. import dataclasses
  2. import gc
  3. import os
  4. import time
  5. import warnings
  6. import weakref
  7. from dataclasses import dataclass
  8. from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
  9. TypeVar, Union)
  10. import numpy as np
  11. import torch
  12. import torch.distributed
  13. import torch.nn as nn
  14. from loguru import logger
  15. try:
  16. from flashinfer import BatchDecodeWithPagedKVCacheWrapper
  17. from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
  18. from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
  19. FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
  20. except ImportError:
  21. BatchDecodeWithPagedKVCacheWrapper = None
  22. CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
  23. BatchPrefillWithPagedKVCacheWrapper = None
  24. FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
  25. from aphrodite.attention import AttentionMetadata, get_attn_backend
  26. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  27. LoRAConfig, ModelConfig, MultiModalConfig,
  28. ParallelConfig, PromptAdapterConfig,
  29. SchedulerConfig)
  30. from aphrodite.common.sampling_params import SamplingParams
  31. from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
  32. SequenceGroupMetadata)
  33. from aphrodite.common.utils import (CudaMemoryProfiler, async_tensor_h2d,
  34. flatten_2d_lists, get_kv_cache_torch_dtype,
  35. is_hip, is_pin_memory_available)
  36. from aphrodite.distributed import get_pp_group
  37. from aphrodite.distributed.parallel_state import (
  38. get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
  39. graph_capture)
  40. from aphrodite.inputs import INPUT_REGISTRY
  41. from aphrodite.lora.layers import LoRAMapping
  42. from aphrodite.lora.request import LoRARequest
  43. from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
  44. from aphrodite.modeling import SamplingMetadata
  45. from aphrodite.modeling.model_loader import get_model
  46. from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
  47. from aphrodite.modeling.models.interfaces import supports_lora, supports_vision
  48. from aphrodite.modeling.models.utils import set_cpu_offload_max_bytes
  49. from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
  50. MultiModalInputs)
  51. from aphrodite.prompt_adapter.layers import PromptAdapterMapping
  52. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  53. from aphrodite.prompt_adapter.worker_manager import (
  54. LRUCacheWorkerPromptAdapterManager)
  55. from aphrodite.task_handler.model_runner_base import (
  56. ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
  57. _add_attn_metadata_broadcastable_dict,
  58. _add_sampling_metadata_broadcastable_dict,
  59. _init_attn_metadata_from_tensor_dict,
  60. _init_sampling_metadata_from_tensor_dict)
  61. if TYPE_CHECKING:
  62. from aphrodite.attention.backends.abstract import AttentionBackend
  63. _PAD_SLOT_ID = -1
  64. LORA_WARMUP_RANK = 8
  65. _BATCH_SIZE_ALIGNMENT = 8
  66. # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
  67. # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
  68. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
  69. _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
  70. ]
  71. _NUM_WARMUP_ITERS = 2
  72. APHRODITE_TEST_DYNAMO_GRAPH_CAPTURE = int(
  73. os.environ.get("APHRODITE_TEST_DYNAMO_GRAPH_CAPTURE", "0"))
  74. TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
  75. @dataclass(frozen=True)
  76. class ModelInputForGPU(ModelRunnerInputBase):
  77. """
  78. This base class contains metadata needed for the base model forward pass
  79. but not metadata for possible additional steps, e.g., sampling. Model
  80. runners that run additional steps should subclass this method to add
  81. additional fields.
  82. """
  83. input_tokens: Optional[torch.Tensor] = None
  84. input_positions: Optional[torch.Tensor] = None
  85. seq_lens: Optional[List[int]] = None
  86. query_lens: Optional[List[int]] = None
  87. lora_mapping: Optional["LoRAMapping"] = None
  88. lora_requests: Optional[Set[LoRARequest]] = None
  89. attn_metadata: Optional["AttentionMetadata"] = None
  90. prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
  91. prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
  92. multi_modal_kwargs: Optional[BatchedTensorInputs] = None
  93. request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
  94. finished_requests_ids: Optional[List[str]] = None
  95. virtual_engine: int = 0
  96. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  97. tensor_dict = {
  98. "input_tokens": self.input_tokens,
  99. "input_positions": self.input_positions,
  100. "lora_requests": self.lora_requests,
  101. "lora_mapping": self.lora_mapping,
  102. "multi_modal_kwargs": self.multi_modal_kwargs,
  103. "prompt_adapter_mapping": self.prompt_adapter_mapping,
  104. "prompt_adapter_requests": self.prompt_adapter_requests,
  105. "virtual_engine": self.virtual_engine,
  106. "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
  107. "finished_requests_ids": self.finished_requests_ids,
  108. }
  109. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  110. return tensor_dict
  111. @classmethod
  112. def from_broadcasted_tensor_dict(
  113. cls: Type[TModelInputForGPU],
  114. tensor_dict: Dict[str, Any],
  115. attn_backend: Optional["AttentionBackend"] = None,
  116. ) -> TModelInputForGPU:
  117. if attn_backend is not None:
  118. tensor_dict = _init_attn_metadata_from_tensor_dict(
  119. attn_backend, tensor_dict)
  120. return cls(**tensor_dict)
  121. @dataclass(frozen=True)
  122. class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
  123. """
  124. Used by the ModelRunner.
  125. """
  126. sampling_metadata: Optional["SamplingMetadata"] = None
  127. # Used for speculative decoding. We do not broadcast it because it is only
  128. # used by the driver worker.
  129. is_prompt: Optional[bool] = None
  130. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  131. tensor_dict = {
  132. "input_tokens": self.input_tokens,
  133. "input_positions": self.input_positions,
  134. "lora_requests": self.lora_requests,
  135. "lora_mapping": self.lora_mapping,
  136. "multi_modal_kwargs": self.multi_modal_kwargs,
  137. "prompt_adapter_mapping": self.prompt_adapter_mapping,
  138. "prompt_adapter_requests": self.prompt_adapter_requests,
  139. "virtual_engine": self.virtual_engine,
  140. "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
  141. "finished_requests_ids": self.finished_requests_ids,
  142. }
  143. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  144. _add_sampling_metadata_broadcastable_dict(tensor_dict,
  145. self.sampling_metadata)
  146. return tensor_dict
  147. @classmethod
  148. def from_broadcasted_tensor_dict(
  149. cls,
  150. tensor_dict: Dict[str, Any],
  151. attn_backend: Optional["AttentionBackend"] = None,
  152. ) -> "ModelInputForGPUWithSamplingMetadata":
  153. tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
  154. if attn_backend is not None:
  155. tensor_dict = _init_attn_metadata_from_tensor_dict(
  156. attn_backend, tensor_dict)
  157. return cls(**tensor_dict)
  158. class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
  159. """Build ModelInputForGPU from SequenceGroupMetadata."""
  160. # NOTE: ideally, we dould be using a dataclass(kw_only=True)
  161. # here, so that this can be subclassed easily, but kw_only
  162. # is not supported in python<3.10.
  163. class InterDataForSeqGroup:
  164. def __init__(
  165. self,
  166. *,
  167. # From sequence group metadata.
  168. request_id: str,
  169. seq_ids: List[int],
  170. is_prompt: bool,
  171. block_tables: Optional[Dict[int, List[int]]],
  172. computed_block_nums: List[int],
  173. n_seqs: int = 0,
  174. # Input tokens and positions.
  175. input_tokens: Optional[List[List[int]]] = None,
  176. input_positions: Optional[List[List[int]]] = None,
  177. # The sequence length (may be capped to the sliding window).
  178. seq_lens: Optional[List[int]] = None,
  179. # The original sequence length (before applying sliding window).
  180. # This is used to compute slot mapping.
  181. orig_seq_lens: Optional[List[int]] = None,
  182. # The query length.
  183. query_lens: Optional[List[int]] = None,
  184. # The number of tokens that are already computed.
  185. context_lens: Optional[List[int]] = None,
  186. # The current sliding window block.
  187. curr_sliding_window_blocks: Optional[List[int]] = None,
  188. # LoRA inputs.
  189. lora_index_mapping: Optional[List[List[int]]] = None,
  190. lora_prompt_mapping: Optional[List[List[int]]] = None,
  191. lora_requests: Optional[Set[LoRARequest]] = None,
  192. # Prompt adapter inputs.
  193. prompt_adapter_index_mapping: Optional[List[int]] = None,
  194. prompt_adapter_prompt_mapping: Optional[List[int]] = None,
  195. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  196. # Multi-modal inputs.
  197. multi_modal_inputs: Optional[MultiModalInputs] = None,
  198. # Whether the prefix cache is hit (prefill only).
  199. prefix_cache_hit: bool = False,
  200. ):
  201. self.request_id = request_id
  202. self.seq_ids = seq_ids
  203. self.is_prompt = is_prompt
  204. self.block_tables = block_tables
  205. self.computed_block_nums = computed_block_nums
  206. self.n_seqs = n_seqs
  207. self.input_tokens = input_tokens or []
  208. self.input_positions = input_positions or []
  209. self.seq_lens = seq_lens or []
  210. self.orig_seq_lens = orig_seq_lens or []
  211. self.query_lens = query_lens or []
  212. self.context_lens = context_lens or []
  213. self.curr_sliding_window_blocks = curr_sliding_window_blocks or []
  214. self.lora_index_mapping = lora_index_mapping or []
  215. self.lora_prompt_mapping = lora_prompt_mapping or []
  216. self.lora_requests = lora_requests or set()
  217. self.prompt_adapter_index_mapping = (prompt_adapter_index_mapping
  218. or [])
  219. self.prompt_adapter_prompt_mapping = (prompt_adapter_prompt_mapping
  220. or [])
  221. self.prompt_adapter_request = prompt_adapter_request
  222. self.multi_modal_inputs = multi_modal_inputs
  223. self.prefix_cache_hit = prefix_cache_hit
  224. self.__post_init__()
  225. def __post_init__(self):
  226. self.n_seqs = len(self.seq_ids)
  227. self.input_tokens = [[] for _ in range(self.n_seqs)]
  228. self.input_positions = [[] for _ in range(self.n_seqs)]
  229. self.seq_lens = [0] * self.n_seqs
  230. self.orig_seq_lens = [0] * self.n_seqs
  231. self.query_lens = [0] * self.n_seqs
  232. self.context_lens = [0] * self.n_seqs
  233. self.curr_sliding_window_blocks = [0] * self.n_seqs
  234. self.lora_index_mapping = [[] for _ in range(self.n_seqs)]
  235. self.lora_prompt_mapping = [[] for _ in range(self.n_seqs)]
  236. def __init__(self,
  237. runner: "GPUModelRunnerBase",
  238. finished_requests_ids: Optional[List[str]] = None):
  239. super().__init__()
  240. # Compute functions for each sequence in a sequence group.
  241. # WARNING: The order of the functions matters!
  242. self.per_seq_compute_fns = [
  243. self._compute_lens,
  244. self._compute_for_prefix_cache_hit,
  245. self._compute_for_sliding_window,
  246. self._compute_lora_input,
  247. ]
  248. # Compute functions for each sequence group.
  249. # WARNING: The order of the functions matters!
  250. self.per_seq_group_compute_fns = [
  251. self._compute_prompt_adapter_input,
  252. self._compute_multi_modal_input,
  253. ]
  254. self.runner = runner
  255. self.model_input_cls = self.runner._model_input_cls
  256. self.attn_backend = self.runner.attn_backend
  257. self.scheduler_config = self.runner.scheduler_config
  258. self.sliding_window = self.runner.sliding_window
  259. self.block_size = self.runner.block_size
  260. self.enable_lora = self.runner.lora_config is not None
  261. self.enable_prompt_adapter = (self.runner.prompt_adapter_config
  262. is not None)
  263. self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
  264. self.finished_requests_ids = finished_requests_ids
  265. self.decode_only = True
  266. # Intermediate data (data in CPU before going to GPU) for
  267. # the current sequence group.
  268. self.inter_data_list: List[
  269. ModelInputForGPUBuilder.InterDataForSeqGroup] = []
  270. # Attention metadata inputs.
  271. self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
  272. weakref.proxy(self))
  273. # Engine/Model configurations.
  274. self.chunked_prefill_enabled = (
  275. self.scheduler_config is not None
  276. and self.scheduler_config.chunked_prefill_enabled)
  277. if self.sliding_window is not None:
  278. self.sliding_window_blocks = (
  279. self.sliding_window + self.block_size - 1) // self.block_size
  280. self.block_aligned_sliding_window = \
  281. self.sliding_window_blocks * self.block_size
  282. def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
  283. seq_group_metadata: SequenceGroupMetadata):
  284. """Compute context length, sequence length and tokens
  285. for the given sequence data.
  286. """
  287. seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]]
  288. token_chunk_size = seq_group_metadata.token_chunk_size
  289. # Compute context length (the number of tokens that are
  290. # already computed) and sequence length (total number of tokens).
  291. seq_len = seq_data.get_len()
  292. if inter_data.is_prompt:
  293. context_len = seq_data.get_num_computed_tokens()
  294. else:
  295. # get_num_computed_tokens is incorrect for spec decoding.
  296. # So, we should have a special logic here.
  297. # TODO(sang): Fix it.
  298. context_len = seq_len - 1
  299. seq_len = min(seq_len, context_len + token_chunk_size)
  300. # Compute tokens.
  301. if inter_data.is_prompt:
  302. tokens = seq_data.get_token_ids()[context_len:seq_len]
  303. else:
  304. # Optimization. get_token_ids requires the entire copy of
  305. # tokens.
  306. tokens = [seq_data.get_last_token_id()]
  307. inter_data.seq_lens[seq_idx] = seq_len
  308. inter_data.orig_seq_lens[seq_idx] = seq_len
  309. inter_data.context_lens[seq_idx] = context_len
  310. inter_data.input_tokens[seq_idx] = tokens
  311. inter_data.input_positions[seq_idx] = list(range(context_len, seq_len))
  312. inter_data.query_lens[
  313. seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
  314. def _compute_for_prefix_cache_hit(
  315. self, inter_data: InterDataForSeqGroup, seq_idx: int,
  316. seq_group_metadata: SequenceGroupMetadata):
  317. """Check if hit prefix cache (i.e., some blocks are already computed).
  318. If hit, update input tokens and positions to only compute the
  319. remaining blocks.
  320. """
  321. computed_block_nums = inter_data.computed_block_nums
  322. # Note that prefix caching does not support sliding window.
  323. prefix_cache_hit = (computed_block_nums is not None
  324. and len(computed_block_nums) > 0
  325. and self.sliding_window is None
  326. and inter_data.is_prompt)
  327. inter_data.prefix_cache_hit = prefix_cache_hit
  328. if self.chunked_prefill_enabled and prefix_cache_hit:
  329. raise RuntimeError(
  330. "chunked prefill cannot be used with prefix caching now.")
  331. # If prefix cache is hit, advance context length to bypass
  332. # hit blocks. Accordingly, input tokens, position and query length
  333. # have to be updated.
  334. if prefix_cache_hit:
  335. assert computed_block_nums is not None
  336. context_len = len(computed_block_nums) * self.block_size
  337. inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
  338. seq_idx][context_len:]
  339. inter_data.input_positions[seq_idx] = inter_data.input_positions[
  340. seq_idx][context_len:]
  341. inter_data.context_lens[seq_idx] = context_len
  342. inter_data.query_lens[
  343. seq_idx] = inter_data.seq_lens[seq_idx] - context_len
  344. def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
  345. seq_idx: int,
  346. seq_group_metadata: SequenceGroupMetadata):
  347. """Update seq_len and curr_sliding_window_block for the given
  348. sequence data (only required by decoding) if sliding window is enabled.
  349. """
  350. curr_sliding_window_block = 0
  351. sliding_seq_len = inter_data.seq_lens[seq_idx]
  352. if not inter_data.is_prompt and self.sliding_window is not None:
  353. # TODO(sang): This is a hack to make sliding window work with
  354. # paged attn. We can remove it if we make paged attn kernel
  355. # to properly handle slinding window attn.
  356. curr_sliding_window_block = self.sliding_window_blocks
  357. if self.scheduler_config.use_v2_block_manager:
  358. # number of elements in last block
  359. suff_len = inter_data.seq_lens[seq_idx] % self.block_size
  360. sliding_seq_len = min(
  361. inter_data.seq_lens[seq_idx],
  362. self.block_aligned_sliding_window + suff_len)
  363. if suff_len > 0:
  364. curr_sliding_window_block += 1
  365. else:
  366. sliding_seq_len = min(inter_data.seq_lens[seq_idx],
  367. self.sliding_window)
  368. inter_data.curr_sliding_window_blocks[
  369. seq_idx] = curr_sliding_window_block
  370. inter_data.seq_lens[seq_idx] = sliding_seq_len
  371. def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
  372. seq_idx: int,
  373. seq_group_metadata: SequenceGroupMetadata):
  374. """If LoRA is enabled, compute LoRA index and prompt mapping."""
  375. if not self.enable_lora:
  376. return
  377. lora_id = seq_group_metadata.lora_int_id
  378. if lora_id > 0:
  379. inter_data.lora_requests.add(seq_group_metadata.lora_request)
  380. query_len = inter_data.query_lens[seq_idx]
  381. inter_data.lora_index_mapping.append([lora_id] * query_len)
  382. inter_data.lora_prompt_mapping.append(
  383. [lora_id] *
  384. (query_len if seq_group_metadata.sampling_params
  385. and seq_group_metadata.sampling_params.prompt_logprobs is not None
  386. else 1))
  387. def _compute_prompt_adapter_input(
  388. self, inter_data: InterDataForSeqGroup,
  389. seq_group_metadata: SequenceGroupMetadata):
  390. """If prompt adapter is enabled, compute index and prompt mapping.
  391. """
  392. # Note that when is_prompt=True, we expect only one sequence
  393. # in the group.
  394. if not self.enable_prompt_adapter:
  395. return
  396. prompt_adapter_id = seq_group_metadata.prompt_adapter_id
  397. if prompt_adapter_id <= 0 or not inter_data.is_prompt:
  398. return
  399. # We expect only one sequence in the group when is_prompt=True.
  400. assert inter_data.n_seqs == 1
  401. query_len = inter_data.query_lens[0]
  402. inter_data.prompt_adapter_request = (
  403. seq_group_metadata.prompt_adapter_request)
  404. num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens
  405. inter_data.prompt_adapter_index_mapping = [
  406. prompt_adapter_id
  407. ] * num_tokens + [0] * (query_len - num_tokens)
  408. inter_data.prompt_adapter_prompt_mapping = [prompt_adapter_id] * (
  409. query_len if seq_group_metadata.sampling_params
  410. and seq_group_metadata.sampling_params.prompt_logprobs else 1)
  411. def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
  412. seq_group_metadata: SequenceGroupMetadata):
  413. """If multi-modal data is given, add it to the input."""
  414. mm_data = seq_group_metadata.multi_modal_data
  415. if not mm_data:
  416. return
  417. mm_kwargs = self.multi_modal_input_mapper(mm_data)
  418. inter_data.multi_modal_inputs = mm_kwargs
  419. def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
  420. """Add a sequence group to the builder."""
  421. seq_ids = list(seq_group_metadata.seq_data.keys())
  422. n_seqs = len(seq_ids)
  423. is_prompt = seq_group_metadata.is_prompt
  424. if is_prompt:
  425. assert n_seqs == 1
  426. self.decode_only = False
  427. inter_data = self.InterDataForSeqGroup(
  428. request_id=seq_group_metadata.request_id,
  429. seq_ids=seq_ids,
  430. is_prompt=is_prompt,
  431. block_tables=seq_group_metadata.block_tables,
  432. computed_block_nums=seq_group_metadata.computed_block_nums)
  433. self.inter_data_list.append(inter_data)
  434. for seq_idx in range(n_seqs):
  435. for per_seq_fn in self.per_seq_compute_fns:
  436. per_seq_fn(inter_data, seq_idx, seq_group_metadata)
  437. for per_seq_group_fn in self.per_seq_group_compute_fns:
  438. per_seq_group_fn(inter_data, seq_group_metadata)
  439. def _use_captured_graph(self, batch_size: int,
  440. max_decode_seq_len: int) -> bool:
  441. return (self.decode_only and not self.runner.model_config.enforce_eager
  442. and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
  443. and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
  444. def build(self) -> ModelInputForGPU:
  445. """Finalize the builder intermediate data and
  446. create on-device tensors.
  447. """
  448. # Combine and flatten intermediate data.
  449. input_tokens = flatten_2d_lists([
  450. flatten_2d_lists(inter_data.input_tokens)
  451. for inter_data in self.inter_data_list
  452. ])
  453. if not input_tokens:
  454. # This may happen when all prefill requests hit
  455. # prefix caching and there is no decode request.
  456. return self.model_input_cls()
  457. input_positions = flatten_2d_lists([
  458. flatten_2d_lists(inter_data.input_positions)
  459. for inter_data in self.inter_data_list
  460. ])
  461. seq_lens = []
  462. max_decode_seq_len = 0
  463. for inter_data in self.inter_data_list:
  464. seq_lens.extend(inter_data.seq_lens)
  465. if not inter_data.is_prompt:
  466. max_decode_seq_len = max(max_decode_seq_len,
  467. max(inter_data.seq_lens))
  468. query_lens = flatten_2d_lists(
  469. [inter_data.query_lens for inter_data in self.inter_data_list])
  470. # Mapping from request IDs to sequence IDs. Used for Jamba models
  471. # that manages the cache by itself.
  472. request_ids_to_seq_ids = {
  473. data.request_id: data.seq_ids
  474. for data in self.inter_data_list
  475. }
  476. batch_size = len(input_tokens)
  477. use_captured_graph = self._use_captured_graph(batch_size,
  478. max_decode_seq_len)
  479. # If cuda graph can be used, pad tensors accordingly.
  480. # See `capture_model` API for more details.
  481. # Aphrodite uses cuda graph only for decoding requests.
  482. cuda_graph_pad_size = -1
  483. if use_captured_graph:
  484. graph_batch_size = _get_graph_batch_size(batch_size)
  485. assert graph_batch_size >= batch_size
  486. cuda_graph_pad_size = graph_batch_size - batch_size
  487. batch_size = graph_batch_size
  488. # Tokens and positions.
  489. input_tokens.extend([0] * cuda_graph_pad_size)
  490. input_positions.extend([0] * cuda_graph_pad_size)
  491. assert self.runner.device is not None
  492. input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
  493. self.runner.device,
  494. self.runner.pin_memory)
  495. input_positions_tensor = async_tensor_h2d(input_positions, torch.long,
  496. self.runner.device,
  497. self.runner.pin_memory)
  498. # Sequence and query lengths.
  499. seq_lens.extend([1] * cuda_graph_pad_size)
  500. # Attention metadata.
  501. attn_metadata = self.attn_metadata_builder.build(
  502. seq_lens, query_lens, cuda_graph_pad_size, batch_size)
  503. # LoRA data.
  504. lora_requests = set()
  505. lora_mapping = None
  506. if self.enable_lora:
  507. lora_requests = set(r for data in self.inter_data_list
  508. for r in data.lora_requests)
  509. lora_index_mapping = flatten_2d_lists([
  510. flatten_2d_lists(inter_data.lora_index_mapping)
  511. for inter_data in self.inter_data_list
  512. ])
  513. lora_index_mapping.extend([0] * cuda_graph_pad_size)
  514. lora_prompt_mapping = flatten_2d_lists([
  515. flatten_2d_lists(inter_data.lora_prompt_mapping)
  516. for inter_data in self.inter_data_list
  517. ])
  518. lora_mapping = LoRAMapping(
  519. **dict(index_mapping=lora_index_mapping,
  520. prompt_mapping=lora_prompt_mapping,
  521. is_prefill=not self.decode_only))
  522. # Prompt adapter data.
  523. prompt_adapter_requests: Set[PromptAdapterRequest] = set()
  524. prompt_adapter_mapping = None
  525. if self.enable_prompt_adapter:
  526. prompt_adapter_requests = set(
  527. data.prompt_adapter_request for data in self.inter_data_list
  528. if data.prompt_adapter_request is not None)
  529. prompt_adapter_index_mapping = flatten_2d_lists([
  530. inter_data.prompt_adapter_index_mapping
  531. for inter_data in self.inter_data_list
  532. ])
  533. prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size)
  534. prompt_adapter_prompt_mapping = flatten_2d_lists([
  535. inter_data.prompt_adapter_prompt_mapping
  536. for inter_data in self.inter_data_list
  537. ])
  538. prompt_adapter_mapping = PromptAdapterMapping(
  539. prompt_adapter_index_mapping,
  540. prompt_adapter_prompt_mapping,
  541. )
  542. # Multi-modal data.
  543. multi_modal_inputs_list = [
  544. data.multi_modal_inputs for data in self.inter_data_list
  545. if data.multi_modal_inputs is not None
  546. ]
  547. multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
  548. return self.model_input_cls(
  549. input_tokens=input_tokens_tensor,
  550. input_positions=input_positions_tensor,
  551. attn_metadata=attn_metadata,
  552. seq_lens=seq_lens,
  553. query_lens=query_lens,
  554. lora_mapping=lora_mapping,
  555. lora_requests=lora_requests,
  556. multi_modal_kwargs=multi_modal_kwargs,
  557. request_ids_to_seq_ids=request_ids_to_seq_ids,
  558. finished_requests_ids=self.finished_requests_ids,
  559. prompt_adapter_mapping=prompt_adapter_mapping,
  560. prompt_adapter_requests=prompt_adapter_requests)
  561. class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
  562. """
  563. Helper class for shared methods between GPU model runners.
  564. """
  565. _model_input_cls: Type[TModelInputForGPU]
  566. _builder_cls: Type[ModelInputForGPUBuilder]
  567. def __init__(
  568. self,
  569. model_config: ModelConfig,
  570. parallel_config: ParallelConfig,
  571. scheduler_config: SchedulerConfig,
  572. device_config: DeviceConfig,
  573. cache_config: CacheConfig,
  574. load_config: LoadConfig,
  575. lora_config: Optional[LoRAConfig],
  576. kv_cache_dtype: Optional[str] = "auto",
  577. is_driver_worker: bool = False,
  578. prompt_adapter_config: Optional[PromptAdapterConfig] = None,
  579. multimodal_config: Optional[MultiModalConfig] = None,
  580. return_hidden_states: bool = False,
  581. tp_rank: int = 0,
  582. ):
  583. self.model_config = model_config
  584. self.parallel_config = parallel_config
  585. self.scheduler_config = scheduler_config
  586. self.device_config = device_config
  587. self.cache_config = cache_config
  588. self.lora_config = lora_config
  589. self.load_config = load_config
  590. self.is_driver_worker = is_driver_worker
  591. self.prompt_adapter_config = prompt_adapter_config
  592. self.multimodal_config = multimodal_config
  593. self.return_hidden_states = return_hidden_states
  594. self.device = self.device_config.device
  595. self.pin_memory = is_pin_memory_available()
  596. self.tp_rank = tp_rank
  597. self.kv_cache_dtype = kv_cache_dtype
  598. self.sliding_window = model_config.get_sliding_window()
  599. self.block_size = cache_config.block_size
  600. self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
  601. self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
  602. {} for _ in range(self.parallel_config.pipeline_parallel_size)
  603. ]
  604. self.graph_memory_pool: Optional[Tuple[
  605. int, int]] = None # Set during graph capture.
  606. self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers(
  607. parallel_config)
  608. # When using CUDA graph, the input block tables must be padded to
  609. # max_seq_len_to_capture. However, creating the block table in
  610. # Python can be expensive. To optimize this, we cache the block table
  611. # in numpy and only copy the actual input content at every iteration.
  612. # The shape of the cached block table will be
  613. # (max batch size to capture, max context len to capture / block size).
  614. self.graph_block_tables = np.zeros(
  615. (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
  616. dtype=np.int32)
  617. num_attn_heads = self.model_config.get_num_attention_heads(
  618. self.parallel_config, self.tp_rank)
  619. self.attn_backend = get_attn_backend(
  620. num_attn_heads,
  621. self.model_config.get_head_size(),
  622. self.model_config.get_num_kv_heads(self.parallel_config,
  623. self.tp_rank),
  624. self.model_config.get_sliding_window(),
  625. self.model_config.dtype,
  626. self.kv_cache_dtype,
  627. self.block_size,
  628. ) if num_attn_heads else None
  629. # Multi-modal data support
  630. self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
  631. .create_input_mapper(self.model_config)
  632. # Lazy initialization
  633. self.model: nn.Module # Set after load_model
  634. # Set after load_model.
  635. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
  636. self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None
  637. self.flashinfer_decode_workspace_buffer = None
  638. self.flashinfer_decode_wrapper = None
  639. self.flashinfer_prefill_workspace_buffer = None
  640. self.flashinfer_prefill_wrapper = None
  641. set_cpu_offload_max_bytes(
  642. int(self.cache_config.cpu_offload_gb * 1024**3))
  643. def load_model(self) -> None:
  644. tp = get_tensor_model_parallel_world_size()
  645. rank = get_tensor_model_parallel_rank()
  646. if rank == 0:
  647. logger.info(f"Loading model {self.model_config.model}...")
  648. with CudaMemoryProfiler() as m:
  649. # measure the time it takes to load the model
  650. start_time = time.time()
  651. self.model = get_model(model_config=self.model_config,
  652. device_config=self.device_config,
  653. load_config=self.load_config,
  654. lora_config=self.lora_config,
  655. multimodal_config=self.multimodal_config,
  656. parallel_config=self.parallel_config,
  657. scheduler_config=self.scheduler_config,
  658. cache_config=self.cache_config)
  659. end_time = time.time()
  660. self.model_memory_usage = m.consumed_memory
  661. total_time = end_time - start_time
  662. if tp > 1:
  663. if rank == 0:
  664. logger.info(f"Model loaded in {total_time:.2f} seconds.")
  665. logger.info(
  666. "Weights memory usage: "
  667. f"{self.model_memory_usage / float(2**30):.2f} GiB x {tp} ="
  668. f" {self.model_memory_usage * tp / float(2**30):.2f} GiB")
  669. else:
  670. logger.info(f"Model weights loaded in {total_time:.2f} seconds.")
  671. logger.info("Weights memory usage: "
  672. f"{self.model_memory_usage / float(2**30):.2f} GiB")
  673. if self.lora_config:
  674. assert supports_lora(self.model), "Model does not support LoRA"
  675. assert not supports_vision(
  676. self.model
  677. ), "To be tested: vision language model with LoRA settings."
  678. self.lora_manager = LRUCacheWorkerLoRAManager(
  679. self.scheduler_config.max_num_seqs,
  680. self.scheduler_config.max_num_batched_tokens,
  681. self.vocab_size,
  682. self.lora_config,
  683. self.device,
  684. self.model.embedding_modules,
  685. self.model.embedding_padding_modules,
  686. max_position_embeddings=self.model.config.
  687. max_position_embeddings,
  688. )
  689. self.model = self.lora_manager.create_lora_manager(self.model)
  690. if self.prompt_adapter_config:
  691. self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
  692. self.scheduler_config.max_num_seqs,
  693. self.scheduler_config.max_num_batched_tokens, self.device,
  694. self.prompt_adapter_config)
  695. self.model = (
  696. self.prompt_adapter_manager.create_prompt_adapter_manager(
  697. self.model))
  698. if self.kv_cache_dtype == "fp8" and is_hip():
  699. # Currently only ROCm accepts kv-cache scaling factors
  700. # via quantization_param_path and this will be deprecated
  701. # in the future.
  702. if self.model_config.quantization_param_path is not None:
  703. if callable(getattr(self.model, "load_kv_cache_scales", None)):
  704. warnings.warn(
  705. "Loading kv cache scaling factor from JSON is "
  706. "deprecated and will be removed. Please include "
  707. "kv cache scaling factors in the model checkpoint.",
  708. FutureWarning,
  709. stacklevel=2)
  710. self.model.load_kv_cache_scales(
  711. self.model_config.quantization_param_path)
  712. logger.info(
  713. "Loaded KV cache scaling factors from ",
  714. f"{self.model_config.quantization_param_path}")
  715. else:
  716. raise RuntimeError(
  717. "Using FP8 KV cache and scaling factors provided but "
  718. f"model {self.model.__class__} does not support loading"
  719. " scaling factors.", )
  720. else:
  721. logger.warning(
  722. "Using FP8 KV cache but no scaling factors "
  723. "provided. Defaulting to scaling factors of 1.0. "
  724. "This may lead to less accurate results!")
  725. if APHRODITE_TEST_DYNAMO_GRAPH_CAPTURE:
  726. logger.info("Compiling the model using torch.compile...")
  727. start_time = time.time()
  728. self.model = torch.compile(self.model,
  729. fullgraph=True,
  730. backend="eager")
  731. end_time = time.time()
  732. logger.info(
  733. f"Model compiled in {end_time - start_time:.2f} seconds.")
  734. def get_model_memory_usage(self):
  735. return self.model_memory_usage
  736. def save_sharded_state(
  737. self,
  738. path: str,
  739. pattern: Optional[str] = None,
  740. max_size: Optional[int] = None,
  741. ) -> None:
  742. from aphrodite.modeling.model_loader.loader import ShardedStateLoader
  743. ShardedStateLoader.save_model(
  744. self.model,
  745. path,
  746. pattern=pattern,
  747. max_size=max_size,
  748. )
  749. def save_tensorized_model(
  750. self,
  751. tensorizer_config: TensorizerConfig,
  752. ) -> None:
  753. from aphrodite.modeling.model_loader.loader import TensorizerLoader
  754. TensorizerLoader.save_model(
  755. self.model,
  756. tensorizer_config=tensorizer_config,
  757. )
  758. def get_max_block_per_batch(self) -> int:
  759. block_size = self.block_size
  760. return (self.max_seq_len_to_capture + block_size - 1) // block_size
  761. def _prepare_model_input_tensors(
  762. self,
  763. seq_group_metadata_list: List[SequenceGroupMetadata],
  764. finished_requests_ids: Optional[List[str]] = None
  765. ) -> TModelInputForGPU:
  766. """Helper method to prepare the model input based on a given sequence
  767. group. Prepares metadata needed for the base model forward pass but not
  768. metadata for possible additional steps, e.g., sampling.
  769. The API assumes seq_group_metadata_list is sorted by prefill -> decode.
  770. The result tensors and data structure also batches input in prefill
  771. -> decode order. For example,
  772. - input_tokens[:num_prefill_tokens] contains prefill tokens.
  773. - input_tokens[num_prefill_tokens:] contains decode tokens.
  774. If cuda graph is required, this API automatically pads inputs.
  775. """
  776. builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
  777. for seq_group_metadata in seq_group_metadata_list:
  778. builder.add_seq_group(seq_group_metadata)
  779. return builder.build() # type: ignore
  780. @torch.inference_mode()
  781. def profile_run(self) -> None:
  782. rank = get_tensor_model_parallel_rank()
  783. if rank == 0:
  784. logger.info("Profiling peak memory usage...")
  785. # Enable top-k sampling to reflect the accurate memory usage.
  786. sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
  787. max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
  788. max_num_seqs = self.scheduler_config.max_num_seqs
  789. # This represents the maximum number of different requests
  790. # that will have unique loras, an therefore the max amount of memory
  791. # consumption create dummy lora request copies from the lora request
  792. # passed in, which contains a lora from the lora warmup path.
  793. dummy_lora_requests: List[LoRARequest] = []
  794. dummy_lora_requests_per_seq: List[LoRARequest] = []
  795. if self.lora_config:
  796. assert self.lora_manager is not None
  797. with self.lora_manager.dummy_lora_cache():
  798. for idx in range(self.lora_config.max_loras):
  799. lora_id = idx + 1
  800. dummy_lora_request = LoRARequest(
  801. lora_name=f"warmup_{lora_id}",
  802. lora_int_id=lora_id,
  803. lora_local_path="/not/a/real/path",
  804. )
  805. self.lora_manager.add_dummy_lora(dummy_lora_request,
  806. rank=LORA_WARMUP_RANK)
  807. dummy_lora_requests.append(dummy_lora_request)
  808. dummy_lora_requests_per_seq = [
  809. dummy_lora_requests[idx % len(dummy_lora_requests)]
  810. for idx in range(max_num_seqs)
  811. ]
  812. # Profile memory usage with max_num_sequences sequences and the total
  813. # number of tokens equal to max_num_batched_tokens.
  814. seqs: List[SequenceGroupMetadata] = []
  815. # Additional GPU memory may be needed for vision encoding, which needs
  816. # to be accounted for when calculating the GPU blocks for
  817. # Aphrodite blocker manager.
  818. # To exercise the worst scenario for GPU memory consumption,
  819. # the number of seqs (batch_size) is chosen to maximize the number
  820. # of images processed.
  821. model_config = self.model_config
  822. if supports_vision(self.model):
  823. max_mm_tokens = MULTIMODAL_REGISTRY \
  824. .get_max_multimodal_tokens(model_config)
  825. max_num_seqs_orig = max_num_seqs
  826. max_num_seqs = min(max_num_seqs,
  827. max_num_batched_tokens // max_mm_tokens)
  828. if max_num_seqs < 1:
  829. expr = (f"min({max_num_seqs_orig}, "
  830. f"{max_num_batched_tokens} // {max_mm_tokens})")
  831. logger.warning(
  832. f"Computed max_num_seqs ({expr}) to be less than 1. "
  833. "Setting it to the minimum value of 1.")
  834. max_num_seqs = 1
  835. batch_size = 0
  836. for group_id in range(max_num_seqs):
  837. seq_len = (max_num_batched_tokens // max_num_seqs +
  838. (group_id < max_num_batched_tokens % max_num_seqs))
  839. batch_size += seq_len
  840. seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
  841. .dummy_data_for_profiling(model_config, seq_len)
  842. # Having more tokens is over-conservative but otherwise fine
  843. assert len(seq_data.prompt_token_ids) >= seq_len, (
  844. f"Expected at least {seq_len} dummy tokens for profiling, "
  845. f"but got: {len(seq_data.prompt_token_ids)}")
  846. seq = SequenceGroupMetadata(
  847. request_id=str(group_id),
  848. is_prompt=True,
  849. seq_data={group_id: seq_data},
  850. sampling_params=sampling_params,
  851. block_tables=None,
  852. lora_request=dummy_lora_requests_per_seq[group_id]
  853. if dummy_lora_requests_per_seq else None,
  854. multi_modal_data=dummy_multi_modal_data,
  855. )
  856. seqs.append(seq)
  857. # Run the model with the dummy inputs.
  858. num_layers = self.model_config.get_num_layers(self.parallel_config)
  859. kv_caches = [None] * num_layers
  860. finished_requests_ids = [seq.request_id for seq in seqs]
  861. model_input = self.prepare_model_input(
  862. seqs, finished_requests_ids=finished_requests_ids)
  863. intermediate_tensors = None
  864. if not get_pp_group().is_first_rank:
  865. intermediate_tensors = self.model.make_empty_intermediate_tensors(
  866. batch_size=batch_size,
  867. dtype=self.model_config.dtype,
  868. device=self.device)
  869. self.execute_model(model_input, kv_caches, intermediate_tensors)
  870. torch.cuda.synchronize()
  871. return
  872. def remove_all_loras(self):
  873. if not self.lora_manager:
  874. raise RuntimeError("LoRA is not enabled.")
  875. self.lora_manager.remove_all_adapters()
  876. def set_active_loras(self, lora_requests: Set[LoRARequest],
  877. lora_mapping: LoRAMapping) -> None:
  878. if not self.lora_manager:
  879. raise RuntimeError("LoRA is not enabled.")
  880. self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
  881. def add_lora(self, lora_request: LoRARequest) -> bool:
  882. if not self.lora_manager:
  883. raise RuntimeError("LoRA is not enabled.")
  884. return self.lora_manager.add_adapter(lora_request)
  885. def remove_lora(self, lora_id: int) -> bool:
  886. if not self.lora_manager:
  887. raise RuntimeError("LoRA is not enabled.")
  888. return self.lora_manager.remove_adapter(lora_id)
  889. def pin_lora(self, lora_id: int) -> bool:
  890. if not self.lora_manager:
  891. raise RuntimeError("LoRA is not enabled.")
  892. return self.lora_manager.pin_adapter(lora_id)
  893. def list_loras(self) -> Set[int]:
  894. if not self.lora_manager:
  895. raise RuntimeError("LoRA is not enabled.")
  896. return self.lora_manager.list_adapters()
  897. def remove_all_prompt_adapters(self):
  898. if not self.prompt_adapter_manager:
  899. raise RuntimeError("PromptAdapter is not enabled.")
  900. self.prompt_adapter_manager.remove_all_adapters()
  901. def set_active_prompt_adapters(
  902. self, prompt_adapter_requests: Set[PromptAdapterRequest],
  903. prompt_adapter_mapping: PromptAdapterMapping) -> None:
  904. if not self.prompt_adapter_manager:
  905. raise RuntimeError("PromptAdapter is not enabled.")
  906. self.prompt_adapter_manager.set_active_adapters(
  907. prompt_adapter_requests, prompt_adapter_mapping)
  908. def add_prompt_adapter(
  909. self, prompt_adapter_request: PromptAdapterRequest) -> bool:
  910. if not self.prompt_adapter_manager:
  911. raise RuntimeError("PromptAdapter is not enabled.")
  912. return self.prompt_adapter_manager.add_adapter(prompt_adapter_request)
  913. def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  914. if not self.prompt_adapter_manager:
  915. raise RuntimeError("PromptAdapter is not enabled.")
  916. return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id)
  917. def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  918. if not self.prompt_adapter_manager:
  919. raise RuntimeError("PromptAdapter is not enabled.")
  920. return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id)
  921. def list_prompt_adapters(self) -> Set[int]:
  922. if not self.prompt_adapter_manager:
  923. raise RuntimeError("PromptAdapter is not enabled.")
  924. return self.prompt_adapter_manager.list_adapters()
  925. @torch.inference_mode()
  926. def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
  927. """Cuda graph capture a model.
  928. Note that CUDA graph's performance gain is negligible if number
  929. of batched tokens are larger than 200. And since CUDA graph
  930. requires fixed sized tensors, supporting large/variable batch
  931. size requires high GPU memory overhead. Thus, Aphrodite only captures
  932. decoding requests. Mixed batch (chunked prefill + decoding) or
  933. prefill requests are not captured.
  934. Since it is used for decoding-only, it assumes there's only 1 token
  935. per sequence in the batch.
  936. """
  937. tp_rank = get_tensor_model_parallel_rank()
  938. assert not self.model_config.enforce_eager
  939. if tp_rank == 0:
  940. logger.info(
  941. "Capturing the model for CUDA graphs. This may lead to "
  942. "unexpected consequences if the model is not static. To "
  943. "run the model in eager mode, set 'enforce_eager=True' or "
  944. "use '--enforce-eager' in the CLI.")
  945. logger.info(
  946. "CUDA graphs can take additional 1~3 GiB memory per GPU. "
  947. "If you are running out of memory, consider decreasing "
  948. "`gpu_memory_utilization` or enforcing eager mode. "
  949. "You can also reduce the `max_num_seqs` as needed "
  950. "to decrease memory usage.")
  951. start_time = time.perf_counter()
  952. # Prepare dummy inputs. These will be reused for all batch sizes.
  953. max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
  954. input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
  955. input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
  956. slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
  957. slot_mapping.fill_(_PAD_SLOT_ID)
  958. seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
  959. block_tables = torch.from_numpy(self.graph_block_tables).cuda()
  960. intermediate_inputs = None
  961. if not get_pp_group().is_first_rank:
  962. intermediate_inputs = self.model.make_empty_intermediate_tensors(
  963. batch_size=max_batch_size,
  964. dtype=self.model_config.dtype,
  965. device=self.device)
  966. # Prepare buffer for outputs. These will be reused for all batch sizes.
  967. # It will be filled after the first graph capture.
  968. hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [
  969. None
  970. ] * self.parallel_config.pipeline_parallel_size
  971. graph_batch_size = _get_graph_batch_size(
  972. self.scheduler_config.max_num_seqs)
  973. batch_size_capture_list = [
  974. bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
  975. ]
  976. if self.attn_backend.get_name() == "flashinfer":
  977. # For flashinfer, different batch sizes will share the
  978. # same workspace buffer.
  979. decode_workspace_buffer = \
  980. torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE,
  981. dtype=torch.uint8,
  982. device=self.device)
  983. indices_buffer = torch.empty(max_batch_size *
  984. self.cache_config.num_gpu_blocks,
  985. dtype=torch.int32,
  986. device=self.device)
  987. indptr_buffer = torch.empty(max_batch_size + 1,
  988. dtype=torch.int32,
  989. device=self.device)
  990. last_page_len_buffer = torch.empty(max_batch_size,
  991. dtype=torch.int32,
  992. device=self.device)
  993. with graph_capture() as graph_capture_context:
  994. # NOTE: Capturing the largest batch size first may help reduce the
  995. # memory usage of CUDA graph.
  996. for virtual_engine in range(
  997. self.parallel_config.pipeline_parallel_size):
  998. for batch_size in reversed(batch_size_capture_list):
  999. if self.attn_backend.get_name() == "flashinfer":
  1000. _indptr_buffer = indptr_buffer[:batch_size + 1]
  1001. _last_page_len_buffer = last_page_len_buffer[:
  1002. batch_size]
  1003. num_qo_heads = (
  1004. self.model_config.get_num_attention_heads(
  1005. self.parallel_config, self.tp_rank))
  1006. num_kv_heads = self.model_config.get_num_kv_heads(
  1007. self.parallel_config, self.tp_rank)
  1008. if num_qo_heads // num_kv_heads >= 4:
  1009. use_tensor_cores = True
  1010. else:
  1011. use_tensor_cores = False
  1012. decode_wrapper = \
  1013. CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
  1014. decode_workspace_buffer, _indptr_buffer,
  1015. indices_buffer, _last_page_len_buffer, "NHD",
  1016. use_tensor_cores)
  1017. kv_cache_dtype = get_kv_cache_torch_dtype(
  1018. self.kv_cache_dtype, self.model_config.dtype)
  1019. paged_kv_indptr_tensor_host = torch.arange(
  1020. 0, batch_size + 1, dtype=torch.int32)
  1021. paged_kv_indices_tensor_host = torch.arange(
  1022. 0, batch_size, dtype=torch.int32)
  1023. paged_kv_last_page_len_tensor_host = torch.full(
  1024. (batch_size, ), self.block_size, dtype=torch.int32)
  1025. query_start_loc_host = torch.arange(0,
  1026. batch_size + 1,
  1027. dtype=torch.int32)
  1028. attn_metadata = self.attn_backend.make_metadata(
  1029. num_prefills=0,
  1030. slot_mapping=slot_mapping[:batch_size],
  1031. num_prefill_tokens=0,
  1032. num_decode_tokens=batch_size,
  1033. max_prefill_seq_len=0,
  1034. block_tables=block_tables,
  1035. paged_kv_indptr=paged_kv_indptr_tensor_host,
  1036. paged_kv_indices=paged_kv_indices_tensor_host,
  1037. paged_kv_last_page_len=
  1038. paged_kv_last_page_len_tensor_host,
  1039. num_qo_heads=num_qo_heads,
  1040. num_kv_heads=num_kv_heads,
  1041. head_dim=self.model_config.get_head_size(),
  1042. page_size=self.block_size,
  1043. seq_start_loc=None,
  1044. query_start_loc=query_start_loc_host,
  1045. device=self.device,
  1046. data_type=kv_cache_dtype,
  1047. use_cuda_graph=True,
  1048. decode_wrapper=decode_wrapper,
  1049. prefill_wrapper=None)
  1050. attn_metadata.begin_forward()
  1051. else:
  1052. attn_metadata = self.attn_backend.make_metadata(
  1053. num_prefills=0,
  1054. num_prefill_tokens=0,
  1055. num_decode_tokens=batch_size,
  1056. slot_mapping=slot_mapping[:batch_size],
  1057. seq_lens=None,
  1058. seq_lens_tensor=seq_lens[:batch_size],
  1059. max_query_len=None,
  1060. max_prefill_seq_len=0,
  1061. max_decode_seq_len=self.max_seq_len_to_capture,
  1062. query_start_loc=None,
  1063. seq_start_loc=None,
  1064. context_lens_tensor=None,
  1065. block_tables=block_tables[:batch_size],
  1066. use_cuda_graph=True,
  1067. )
  1068. if self.lora_config:
  1069. lora_mapping = LoRAMapping(
  1070. **dict(index_mapping=[0] * batch_size,
  1071. prompt_mapping=[0] * batch_size,
  1072. is_prefill=False))
  1073. self.set_active_loras(set(), lora_mapping)
  1074. if self.prompt_adapter_config:
  1075. prompt_adapter_mapping = PromptAdapterMapping(
  1076. [-1] * batch_size,
  1077. [-1] * batch_size,
  1078. )
  1079. self.set_active_prompt_adapters(
  1080. set(), prompt_adapter_mapping)
  1081. graph_runner = CUDAGraphRunner(
  1082. self.model, self.attn_backend.get_name())
  1083. if self.attn_backend.get_name() == "flashinfer":
  1084. graph_runner.flashinfer_indptr_buffer = _indptr_buffer
  1085. graph_runner.flashinfer_indices_buffer = indices_buffer
  1086. graph_runner.flashinfer_last_page_len_buffer = \
  1087. _last_page_len_buffer
  1088. graph_runner.flashinfer_decode_workspace_buffer = \
  1089. decode_workspace_buffer
  1090. graph_runner.flashinfer_decode_wrapper = \
  1091. decode_wrapper
  1092. capture_inputs = {
  1093. "input_ids":
  1094. input_tokens[:batch_size],
  1095. "positions":
  1096. input_positions[:batch_size],
  1097. "hidden_or_intermediate_states":
  1098. hidden_or_intermediate_states[
  1099. virtual_engine] # type: ignore
  1100. [:batch_size]
  1101. if hidden_or_intermediate_states[virtual_engine]
  1102. is not None else None,
  1103. "intermediate_inputs":
  1104. intermediate_inputs[:batch_size]
  1105. if intermediate_inputs is not None else None,
  1106. "kv_caches":
  1107. kv_caches[virtual_engine],
  1108. "attn_metadata":
  1109. attn_metadata,
  1110. "memory_pool":
  1111. self.graph_memory_pool,
  1112. "stream":
  1113. graph_capture_context.stream
  1114. }
  1115. if self.has_seqlen_agnostic:
  1116. # Only used by Mamba-based models CUDA graph atm (Jamba)
  1117. capture_inputs.update({
  1118. "seqlen_agnostic_capture_inputs":
  1119. self.model.get_seqlen_agnostic_capture_inputs(
  1120. batch_size)
  1121. })
  1122. graph_runner.capture(**capture_inputs)
  1123. self.graph_memory_pool = graph_runner.graph.pool()
  1124. self.graph_runners[virtual_engine][batch_size] = (
  1125. graph_runner)
  1126. end_time = time.perf_counter()
  1127. elapsed_time = end_time - start_time
  1128. # This usually takes < 10 seconds.
  1129. if tp_rank == 0:
  1130. logger.info(f"Graph capturing finished in {elapsed_time:.2f} secs")
  1131. @property
  1132. def vocab_size(self) -> int:
  1133. return self.model_config.get_vocab_size()
  1134. class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
  1135. """
  1136. GPU model runner with sampling step.
  1137. """
  1138. _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
  1139. ModelInputForGPUWithSamplingMetadata)
  1140. _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
  1141. def make_model_input_from_broadcasted_tensor_dict(
  1142. self,
  1143. tensor_dict: Dict[str, Any],
  1144. ) -> ModelInputForGPUWithSamplingMetadata:
  1145. model_input = \
  1146. ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
  1147. tensor_dict,
  1148. attn_backend=self.attn_backend,
  1149. )
  1150. return model_input
  1151. def prepare_model_input(
  1152. self,
  1153. seq_group_metadata_list: List[SequenceGroupMetadata],
  1154. virtual_engine: int = 0,
  1155. finished_requests_ids: Optional[List[str]] = None
  1156. ) -> ModelInputForGPUWithSamplingMetadata:
  1157. """Prepare the model input based on a given sequence group, including
  1158. metadata for the sampling step.
  1159. The API assumes seq_group_metadata_list is sorted by prefill -> decode.
  1160. The result tensors and data structure also batches input in prefill
  1161. -> decode order. For example,
  1162. - input_tokens[:num_prefill_tokens] contains prefill tokens.
  1163. - input_tokens[num_prefill_tokens:] contains decode tokens.
  1164. If cuda graph is required, this API automatically pads inputs.
  1165. """
  1166. model_input = self._prepare_model_input_tensors(
  1167. seq_group_metadata_list, finished_requests_ids)
  1168. if get_pp_group().is_last_rank:
  1169. # Sampling metadata is only required for the final pp group
  1170. generators = self.get_generators(finished_requests_ids)
  1171. sampling_metadata = SamplingMetadata.prepare(
  1172. seq_group_metadata_list, model_input.seq_lens,
  1173. model_input.query_lens, self.device, self.pin_memory,
  1174. generators)
  1175. else:
  1176. sampling_metadata = None
  1177. is_prompt = (seq_group_metadata_list[0].is_prompt
  1178. if seq_group_metadata_list else None)
  1179. return dataclasses.replace(model_input,
  1180. sampling_metadata=sampling_metadata,
  1181. is_prompt=is_prompt,
  1182. virtual_engine=virtual_engine)
  1183. @torch.inference_mode()
  1184. def execute_model(
  1185. self,
  1186. model_input: ModelInputForGPUWithSamplingMetadata,
  1187. kv_caches: List[torch.Tensor],
  1188. intermediate_tensors: Optional[IntermediateTensors] = None,
  1189. num_steps: int = 1,
  1190. ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
  1191. if num_steps > 1:
  1192. raise ValueError("num_steps > 1 is not supported in ModelRunner")
  1193. if self.lora_config:
  1194. assert model_input.lora_requests is not None
  1195. assert model_input.lora_mapping is not None
  1196. self.set_active_loras(model_input.lora_requests,
  1197. model_input.lora_mapping)
  1198. if self.prompt_adapter_config:
  1199. assert model_input.prompt_adapter_requests is not None
  1200. assert model_input.prompt_adapter_mapping is not None
  1201. self.set_active_prompt_adapters(
  1202. model_input.prompt_adapter_requests,
  1203. model_input.prompt_adapter_mapping)
  1204. if self.attn_backend.get_name() == "flashinfer":
  1205. assert model_input.attn_metadata is not None
  1206. assert model_input.input_tokens is not None
  1207. if self.flashinfer_decode_workspace_buffer is None:
  1208. self.flashinfer_decode_workspace_buffer = torch.empty(
  1209. FLASHINFER_WORKSPACE_BUFFER_SIZE,
  1210. dtype=torch.uint8,
  1211. device=self.device)
  1212. self.flashinfer_decode_wrapper = \
  1213. BatchDecodeWithPagedKVCacheWrapper(
  1214. self.flashinfer_decode_workspace_buffer, "NHD")
  1215. self.flashinfer_prefill_workspace_buffer = torch.empty(
  1216. FLASHINFER_WORKSPACE_BUFFER_SIZE,
  1217. dtype=torch.uint8,
  1218. device=self.device)
  1219. self.flashinfer_prefill_wrapper = \
  1220. BatchPrefillWithPagedKVCacheWrapper(
  1221. self.flashinfer_prefill_workspace_buffer, "NHD")
  1222. model_input.attn_metadata.prefill_wrapper = \
  1223. self.flashinfer_prefill_wrapper
  1224. if model_input.attn_metadata.use_cuda_graph:
  1225. batch_size = model_input.input_tokens.shape[0]
  1226. model_input.attn_metadata.decode_wrapper = self.graph_runners[
  1227. model_input.
  1228. virtual_engine][batch_size].flashinfer_decode_wrapper
  1229. else:
  1230. model_input.attn_metadata.decode_wrapper = \
  1231. self.flashinfer_decode_wrapper
  1232. model_input.attn_metadata.begin_forward()
  1233. # Currently cuda graph is only supported by the decode phase.
  1234. assert model_input.attn_metadata is not None
  1235. prefill_meta = model_input.attn_metadata.prefill_metadata
  1236. decode_meta = model_input.attn_metadata.decode_metadata
  1237. # TODO: We can remove this once all
  1238. # virtual engines share the same kv cache.
  1239. virtual_engine = model_input.virtual_engine
  1240. if prefill_meta is None and decode_meta.use_cuda_graph:
  1241. assert model_input.input_tokens is not None
  1242. graph_batch_size = model_input.input_tokens.shape[0]
  1243. model_executable = self.graph_runners[virtual_engine][
  1244. graph_batch_size]
  1245. else:
  1246. model_executable = self.model
  1247. multi_modal_kwargs = model_input.multi_modal_kwargs or {}
  1248. seqlen_agnostic_kwargs = {
  1249. "finished_requests_ids": model_input.finished_requests_ids,
  1250. "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
  1251. } if self.has_seqlen_agnostic else {}
  1252. hidden_or_intermediate_states = model_executable(
  1253. input_ids=model_input.input_tokens,
  1254. positions=model_input.input_positions,
  1255. kv_caches=kv_caches,
  1256. attn_metadata=model_input.attn_metadata,
  1257. intermediate_tensors=intermediate_tensors,
  1258. **MultiModalInputs.as_kwargs(multi_modal_kwargs,
  1259. device=self.device),
  1260. **seqlen_agnostic_kwargs,
  1261. )
  1262. # Compute the logits in the last pipeline stage.
  1263. if not get_pp_group().is_last_rank:
  1264. return hidden_or_intermediate_states
  1265. logits = self.model.compute_logits(hidden_or_intermediate_states,
  1266. model_input.sampling_metadata)
  1267. if not self.is_driver_worker:
  1268. return []
  1269. # Sample the next token.
  1270. output: SamplerOutput = self.model.sample(
  1271. logits=logits,
  1272. sampling_metadata=model_input.sampling_metadata,
  1273. )
  1274. if self.return_hidden_states:
  1275. # we only need to pass hidden states of most recent token
  1276. assert model_input.sampling_metadata is not None
  1277. indices = model_input.sampling_metadata.selected_token_indices
  1278. if model_input.is_prompt:
  1279. hidden_states = hidden_or_intermediate_states.index_select(
  1280. 0, indices)
  1281. elif decode_meta.use_cuda_graph:
  1282. hidden_states = hidden_or_intermediate_states[:len(indices)]
  1283. else:
  1284. hidden_states = hidden_or_intermediate_states
  1285. output.hidden_states = hidden_states
  1286. return [output]
class CUDAGraphRunner:
    """Records one forward pass of ``model`` into a ``torch.cuda.CUDAGraph``
    and replays it.

    ``capture`` runs the model a few times eagerly, then records a single
    forward pass into a CUDA graph. It keeps references to the tensors used
    as inputs and outputs during capture in ``input_buffers`` and
    ``output_buffers``. ``forward`` copies new inputs into those same
    tensors, replays the graph, and returns the output buffer.
    """

    def __init__(self, model: nn.Module, backend_name: str):
        # The wrapped model; called directly for warm-up and capture, and
        # for the seqlen-agnostic copy hooks in `forward`.
        self.model = model
        # Attention backend name. "flashinfer" changes which attention
        # metadata tensors are kept as static input buffers.
        self.backend_name = backend_name

        # Static input tensors recorded by `capture`, keyed by name. Besides
        # tensors this also holds "kv_caches" (a list) and any **kwargs.
        self.input_buffers: Dict[str, torch.Tensor] = {}
        # On the last pipeline-parallel rank `capture` stores
        # {"hidden_states": <tensor>}; on other ranks it stores the
        # IntermediateTensors object itself.
        self.output_buffers: Union[Dict[str, torch.Tensor],
                                   IntermediateTensors] = {}

        # Set by `capture`; None until a graph has been recorded.
        self._graph: Optional[torch.cuda.CUDAGraph] = None

        # FlashInfer-specific state. Nothing in this class assigns these.
        # NOTE(review): presumably populated by the code that sets up the
        # runner for the flashinfer backend; confirm at the call site.
        self.flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None
        self.flashinfer_indptr_buffer: Optional[torch.Tensor] = None
        self.flashinfer_indices_buffer: Optional[torch.Tensor] = None
        self.flashinfer_last_page_len_buffer: Optional[torch.Tensor] = None
        self.flashinfer_decode_wrapper: Optional[
            CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None

    @property
    def graph(self) -> torch.cuda.CUDAGraph:
        """The captured graph.

        Raises:
            AssertionError: If ``capture`` has not been called yet.
        """
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
                                                      torch.Tensor]],
        intermediate_inputs: Optional[IntermediateTensors],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.cuda.Stream,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Warm up the model, then record one forward pass into a CUDA graph.

        The tensors passed in become the graph's static buffers. They are
        stored in ``input_buffers`` / ``output_buffers``, and ``forward``
        later overwrites the inputs in place before each replay.

        Args:
            input_ids: Input token ids used for capture; kept as a buffer.
            positions: Position tensor used for capture; kept as a buffer.
            hidden_or_intermediate_states: Optional preallocated output.
                When given, the model output is copied into it inside the
                captured region. On the last pipeline-parallel rank it is
                treated as a single tensor; otherwise each entry of its
                ``.tensors`` is copied. When None, the model's own output
                object is used as the output buffer.
            intermediate_inputs: Passed to the model. If not None, its
                tensors are added to ``input_buffers``.
            kv_caches: Passed to the model and kept in ``input_buffers``.
            attn_metadata: Passed to the model. ``slot_mapping`` is kept as
                a buffer. For backends other than "flashinfer",
                ``decode_metadata.seq_lens_tensor`` and
                ``decode_metadata.block_tables`` are kept as well.
            memory_pool: Passed as ``pool=`` to ``torch.cuda.graph``.
            stream: Passed as ``stream=`` to ``torch.cuda.graph``.
            **kwargs: Extra model inputs. They are forwarded to every model
                call and merged into ``input_buffers``.

        Returns:
            The output buffer: a tensor on the last pipeline-parallel rank,
            an IntermediateTensors otherwise.

        Raises:
            AssertionError: If a graph has already been captured.
        """
        assert self._graph is None
        # Run the model a few times without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        # Note one iteration is not enough for torch.jit.script
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                intermediate_inputs,
                **kwargs,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            output_hidden_or_intermediate_states = self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                intermediate_inputs,
                **kwargs,
            )
            if hidden_or_intermediate_states is not None:
                # The copy is recorded as part of the graph, so every replay
                # also refreshes the caller's preallocated output buffer.
                if get_pp_group().is_last_rank:
                    hidden_or_intermediate_states.copy_(
                        output_hidden_or_intermediate_states)
                else:
                    for key in hidden_or_intermediate_states.tensors:
                        hidden_or_intermediate_states[key].copy_(
                            output_hidden_or_intermediate_states[key])
            else:
                hidden_or_intermediate_states = (
                    output_hidden_or_intermediate_states)

            del output_hidden_or_intermediate_states
            # make sure `output_hidden_or_intermediate_states` is deleted
            # in the graph's memory pool
            gc.collect()
        torch.cuda.synchronize()

        # Save the input and output buffers.
        if self.backend_name == "flashinfer":
            # The flashinfer path keeps no seq_lens/block_tables buffers;
            # `forward` skips copying them for this backend as well.
            self.input_buffers = {
                "input_ids": input_ids,
                "positions": positions,
                "kv_caches": kv_caches,
                "slot_mapping": attn_metadata.slot_mapping,
                **kwargs,
            }
        else:
            self.input_buffers = {
                "input_ids": input_ids,
                "positions": positions,
                "kv_caches": kv_caches,
                "slot_mapping": attn_metadata.slot_mapping,
                "seq_lens_tensor":
                attn_metadata.decode_metadata.seq_lens_tensor,
                "block_tables": attn_metadata.decode_metadata.block_tables,
                **kwargs,
            }
        if intermediate_inputs is not None:
            self.input_buffers.update(intermediate_inputs.tensors)
        if get_pp_group().is_last_rank:
            self.output_buffers = {
                "hidden_states": hidden_or_intermediate_states
            }
        else:
            self.output_buffers = hidden_or_intermediate_states
        return hidden_or_intermediate_states

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Copy new inputs into the captured buffers and replay the graph.

        NOTE(review): ``Tensor.copy_`` needs each source to be broadcastable
        to its buffer's shape. Callers presumably pad inputs to the captured
        batch size; confirm at the call site.

        Args:
            input_ids: Copied into ``input_buffers["input_ids"]``.
            positions: Copied into ``input_buffers["positions"]``.
            kv_caches: Ignored. The tensors recorded at capture time are
                reused as-is.
            attn_metadata: ``slot_mapping`` is always copied. For backends
                other than "flashinfer", ``decode_metadata.seq_lens_tensor``
                and ``decode_metadata.block_tables`` are copied too.
            intermediate_tensors: If given, each of its tensors is copied
                into the input buffer with the same key.
            **kwargs: Passed to the model's copy-before/copy-after hooks.
                These hooks run only when the graph was captured with a
                ``seqlen_agnostic_capture_inputs`` keyword.

        Returns:
            On the last pipeline-parallel rank, the captured
            ``"hidden_states"`` output tensor; otherwise the captured
            IntermediateTensors. These are the static buffers themselves,
            not copies, so callers should assume the next replay overwrites
            them.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                 non_blocking=True)
        if self.backend_name != "flashinfer":
            # Only non-flashinfer captures stored these buffers (see
            # `capture`).
            self.input_buffers["seq_lens_tensor"].copy_(
                attn_metadata.decode_metadata.seq_lens_tensor,
                non_blocking=True)
            self.input_buffers["block_tables"].copy_(
                attn_metadata.decode_metadata.block_tables, non_blocking=True)
        # This key is present only if it was among the **kwargs at capture
        # time, which were merged into input_buffers.
        if "seqlen_agnostic_capture_inputs" in self.input_buffers:
            self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
                                                      **kwargs)

        if intermediate_tensors is not None:
            for key in intermediate_tensors.tensors:
                self.input_buffers[key].copy_(intermediate_tensors[key],
                                              non_blocking=True)
        # Run the graph.
        self.graph.replay()
        if "seqlen_agnostic_capture_inputs" in self.input_buffers:
            self.model.copy_outputs_after_cuda_graphs(self.input_buffers,
                                                      **kwargs)
        # Return the output tensor.
        if get_pp_group().is_last_rank:
            return self.output_buffers["hidden_states"]

        return self.output_buffers

    def __call__(self, *args, **kwargs):
        """Alias for ``forward`` so the runner can be called like the model."""
        return self.forward(*args, **kwargs)
  1428. def _get_graph_batch_size(batch_size: int) -> int:
  1429. """Returns the padded batch size given actual batch size.
  1430. Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
  1431. 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
  1432. """
  1433. if batch_size <= 2:
  1434. return batch_size
  1435. elif batch_size <= 4:
  1436. return 4
  1437. else:
  1438. return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
  1439. _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)