1
0

model_runner.py 78 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804
  1. import dataclasses
  2. import gc
  3. import inspect
  4. import itertools
  5. import time
  6. import warnings
  7. import weakref
  8. from dataclasses import dataclass
  9. from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
  10. Tuple, Type, TypeVar, Union)
  11. import numpy as np
  12. import torch
  13. import torch.distributed
  14. import torch.nn as nn
  15. from loguru import logger
  16. import aphrodite.common.envs as envs
  17. from aphrodite.attention import AttentionMetadata, get_attn_backend
  18. from aphrodite.attention.backends.abstract import AttentionState
  19. from aphrodite.attention.backends.utils import CommonAttentionState
  20. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  21. LoRAConfig, ModelConfig, ParallelConfig,
  22. PromptAdapterConfig, SchedulerConfig)
  23. from aphrodite.common.sampling_params import SamplingParams
  24. from aphrodite.common.sequence import (IntermediateTensors,
  25. SequenceGroupMetadata)
  26. from aphrodite.common.utils import (CudaMemoryProfiler, PyObjectCache,
  27. async_tensor_h2d, flatten_2d_lists, is_hip,
  28. is_pin_memory_available, supports_dynamo)
  29. from aphrodite.distributed import get_pp_group
  30. from aphrodite.distributed.parallel_state import (
  31. get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
  32. graph_capture)
  33. from aphrodite.inputs import INPUT_REGISTRY, InputRegistry
  34. from aphrodite.lora.layers import LoRAMapping
  35. from aphrodite.lora.request import LoRARequest
  36. from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
  37. from aphrodite.modeling.layers.rotary_embedding import MRotaryEmbedding
  38. from aphrodite.modeling.layers.sampler import SamplerOutput
  39. from aphrodite.modeling.model_loader import get_model
  40. from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
  41. from aphrodite.modeling.models.interfaces import (supports_lora,
  42. supports_multimodal)
  43. from aphrodite.modeling.models.utils import set_cpu_offload_max_bytes
  44. from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
  45. SamplingMetadataCache)
  46. from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
  47. MultiModalInputs, MultiModalRegistry)
  48. from aphrodite.processing.scheduler import SchedulerOutputs
  49. from aphrodite.prompt_adapter.layers import PromptAdapterMapping
  50. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  51. from aphrodite.prompt_adapter.worker_manager import (
  52. LRUCacheWorkerPromptAdapterManager)
  53. from aphrodite.worker.model_runner_base import (
  54. ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
  55. _add_attn_metadata_broadcastable_dict,
  56. _add_sampling_metadata_broadcastable_dict,
  57. _init_attn_metadata_from_tensor_dict,
  58. _init_sampling_metadata_from_tensor_dict, dump_input_when_exception)
  59. if TYPE_CHECKING:
  60. from aphrodite.attention.backends.abstract import AttentionBackend
  61. LORA_WARMUP_RANK = 8
  62. _BATCH_SIZE_ALIGNMENT = 8
  63. # all the token sizes that **can** be captured by cudagraph.
  64. # they can be arbitrarily large.
  65. # currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
  66. # the actual sizes to capture will be determined by the model,
  67. # depending on the model's max_num_seqs.
  68. # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
  69. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
  70. _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
  71. ]
  72. _NUM_WARMUP_ITERS = 2
  73. TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
  74. # For now, bump up cache limits for recompilations during CUDA graph warmups.
  75. torch._dynamo.config.cache_size_limit = 128
  76. torch._dynamo.config.accumulated_cache_size_limit = 128
  77. @dataclass(frozen=True)
  78. class ModelInputForGPU(ModelRunnerInputBase):
  79. """
  80. This base class contains metadata needed for the base model forward pass
  81. but not metadata for possible additional steps, e.g., sampling. Model
  82. runners that run additional steps should subclass this method to add
  83. additional fields.
  84. """
  85. input_tokens: Optional[torch.Tensor] = None
  86. input_positions: Optional[torch.Tensor] = None
  87. seq_lens: Optional[List[int]] = None
  88. query_lens: Optional[List[int]] = None
  89. lora_mapping: Optional["LoRAMapping"] = None
  90. lora_requests: Optional[Set[LoRARequest]] = None
  91. attn_metadata: Optional["AttentionMetadata"] = None
  92. prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
  93. prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
  94. multi_modal_kwargs: Optional[BatchedTensorInputs] = None
  95. request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
  96. finished_requests_ids: Optional[List[str]] = None
  97. virtual_engine: int = 0
  98. async_callback: Optional[Callable] = None
  99. seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
  100. scheduler_outputs: Optional[SchedulerOutputs] = None
  101. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  102. tensor_dict = {
  103. "input_tokens": self.input_tokens,
  104. "input_positions": self.input_positions,
  105. "lora_requests": self.lora_requests,
  106. "lora_mapping": self.lora_mapping,
  107. "multi_modal_kwargs": self.multi_modal_kwargs,
  108. "prompt_adapter_mapping": self.prompt_adapter_mapping,
  109. "prompt_adapter_requests": self.prompt_adapter_requests,
  110. "virtual_engine": self.virtual_engine,
  111. "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
  112. "finished_requests_ids": self.finished_requests_ids,
  113. }
  114. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  115. return tensor_dict
  116. @classmethod
  117. def from_broadcasted_tensor_dict(
  118. cls: Type[TModelInputForGPU],
  119. tensor_dict: Dict[str, Any],
  120. attn_backend: Optional["AttentionBackend"] = None,
  121. ) -> TModelInputForGPU:
  122. if attn_backend is not None:
  123. tensor_dict = _init_attn_metadata_from_tensor_dict(
  124. attn_backend, tensor_dict)
  125. return cls(**tensor_dict)
  126. @dataclass(frozen=True)
  127. class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
  128. """
  129. Used by the ModelRunner.
  130. """
  131. sampling_metadata: Optional["SamplingMetadata"] = None
  132. # Used for speculative decoding. We do not broadcast it because it is only
  133. # used by the driver worker.
  134. is_prompt: Optional[bool] = None
  135. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  136. tensor_dict = {
  137. "input_tokens": self.input_tokens,
  138. "input_positions": self.input_positions,
  139. "lora_requests": self.lora_requests,
  140. "lora_mapping": self.lora_mapping,
  141. "multi_modal_kwargs": self.multi_modal_kwargs,
  142. "prompt_adapter_mapping": self.prompt_adapter_mapping,
  143. "prompt_adapter_requests": self.prompt_adapter_requests,
  144. "virtual_engine": self.virtual_engine,
  145. "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
  146. "finished_requests_ids": self.finished_requests_ids,
  147. }
  148. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  149. _add_sampling_metadata_broadcastable_dict(tensor_dict,
  150. self.sampling_metadata)
  151. return tensor_dict
  152. @classmethod
  153. def from_broadcasted_tensor_dict(
  154. cls,
  155. tensor_dict: Dict[str, Any],
  156. attn_backend: Optional["AttentionBackend"] = None,
  157. ) -> "ModelInputForGPUWithSamplingMetadata":
  158. tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
  159. if attn_backend is not None:
  160. tensor_dict = _init_attn_metadata_from_tensor_dict(
  161. attn_backend, tensor_dict)
  162. return cls(**tensor_dict)
  163. class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
  164. """Build ModelInputForGPU from SequenceGroupMetadata."""
  165. # NOTE: ideally, we dould be using a dataclass(kw_only=True)
  166. # here, so that this can be subclassed easily, but kw_only
  167. # is not supported in python<3.10.
  168. class InterDataForSeqGroup:
  169. """Intermediate data for the current sequence group."""
  170. def simple_reinit(self):
  171. self.input_tokens[0].clear() # type: ignore
  172. self.input_positions[0].clear() # type: ignore
  173. self.mrope_input_positions = None # type: ignore
  174. self.seq_lens[0] = 0 # type: ignore
  175. self.orig_seq_lens[0] = 0 # type: ignore
  176. self.query_lens[0] = 0 # type: ignore
  177. self.context_lens[0] = 0 # type: ignore
  178. self.curr_sliding_window_blocks[0] = 0 # type: ignore
  179. self.lora_index_mapping.clear() # type: ignore
  180. self.lora_prompt_mapping.clear() # type: ignore
  181. self.lora_requests.clear() # type: ignore
  182. self.prompt_adapter_index_mapping.clear() # type: ignore
  183. self.prompt_adapter_prompt_mapping.clear() # type: ignore
  184. def __init__(
  185. self,
  186. *,
  187. # From sequence group metadata.
  188. request_id: str,
  189. seq_ids: List[int],
  190. is_prompt: bool,
  191. block_tables: Optional[Dict[int, List[int]]],
  192. computed_block_nums: List[int],
  193. n_seqs: int = 0,
  194. # Input tokens and positions.
  195. input_tokens: Optional[List[List[int]]] = None,
  196. input_positions: Optional[List[List[int]]] = None,
  197. mrope_input_positions: Optional[List[List[List[int]]]] = None,
  198. # The sequence length (may be capped to the sliding window).
  199. seq_lens: Optional[List[int]] = None,
  200. # The original sequence length (before applying sliding window).
  201. # This is used to compute slot mapping.
  202. orig_seq_lens: Optional[List[int]] = None,
  203. # The query length.
  204. query_lens: Optional[List[int]] = None,
  205. # The number of tokens that are already computed.
  206. context_lens: Optional[List[int]] = None,
  207. # The current sliding window block.
  208. curr_sliding_window_blocks: Optional[List[int]] = None,
  209. # LoRA inputs.
  210. lora_index_mapping: Optional[List[List[int]]] = None,
  211. lora_prompt_mapping: Optional[List[List[int]]] = None,
  212. lora_requests: Optional[Set[LoRARequest]] = None,
  213. # Prompt adapter inputs.
  214. prompt_adapter_index_mapping: Optional[List[int]] = None,
  215. prompt_adapter_prompt_mapping: Optional[List[int]] = None,
  216. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  217. # Multi-modal inputs.
  218. multi_modal_inputs: Optional[MultiModalInputs] = None,
  219. # Whether the prefix cache is hit (prefill only).
  220. prefix_cache_hit: bool = False,
  221. reinit: bool = False,
  222. reinit_use_defaults: bool = False,
  223. ):
  224. if reinit:
  225. assert len(self.seq_ids) == len(seq_ids) # type: ignore
  226. for i, seq_id in enumerate(seq_ids):
  227. self.seq_ids[i] = seq_id # type: ignore
  228. else:
  229. self.seq_ids = seq_ids
  230. self.request_id = request_id
  231. self.is_prompt = is_prompt
  232. self.block_tables = block_tables
  233. self.computed_block_nums = computed_block_nums
  234. self.n_seqs = n_seqs
  235. if reinit:
  236. if len(self.seq_ids) == 1 and reinit_use_defaults:
  237. self.simple_reinit()
  238. else:
  239. if input_tokens:
  240. self.input_tokens = input_tokens
  241. else:
  242. for seq_id in range(len(self.seq_ids)):
  243. self.input_tokens[seq_id].clear()
  244. if input_positions:
  245. self.input_positions = input_positions
  246. else:
  247. for seq_id in range(len(self.seq_ids)):
  248. self.input_positions[seq_id].clear()
  249. self.mrope_input_positions = None
  250. if seq_lens:
  251. self.seq_lens = seq_lens
  252. else:
  253. for seq_id in range(len(self.seq_ids)):
  254. self.seq_lens[seq_id] = 0
  255. if orig_seq_lens:
  256. self.orig_seq_lens = orig_seq_lens
  257. else:
  258. for seq_id in range(len(self.seq_ids)):
  259. self.orig_seq_lens[seq_id] = 0
  260. if query_lens:
  261. self.query_lens = query_lens
  262. else:
  263. for seq_id in range(len(self.seq_ids)):
  264. self.query_lens[seq_id] = 0
  265. if context_lens:
  266. self.context_lens = context_lens
  267. else:
  268. for seq_id in range(len(self.seq_ids)):
  269. self.context_lens[seq_id] = 0
  270. if curr_sliding_window_blocks:
  271. self.curr_sliding_window_blocks = \
  272. curr_sliding_window_blocks
  273. else:
  274. for seq_id in range(len(self.seq_ids)):
  275. self.curr_sliding_window_blocks[seq_id] = 0
  276. if lora_index_mapping:
  277. self.lora_index_mapping = lora_index_mapping
  278. else:
  279. self.lora_index_mapping.clear()
  280. if lora_prompt_mapping:
  281. self.lora_prompt_mapping = lora_prompt_mapping
  282. else:
  283. self.lora_prompt_mapping.clear()
  284. if lora_requests:
  285. self.lora_requests = lora_requests
  286. else:
  287. self.lora_requests.clear()
  288. if prompt_adapter_index_mapping:
  289. self.prompt_adapter_index_mapping = \
  290. prompt_adapter_index_mapping
  291. else:
  292. self.prompt_adapter_index_mapping.clear()
  293. if prompt_adapter_prompt_mapping:
  294. self.prompt_adapter_prompt_mapping = \
  295. prompt_adapter_prompt_mapping
  296. else:
  297. self.prompt_adapter_prompt_mapping.clear()
  298. else:
  299. self.input_tokens = input_tokens or []
  300. self.input_positions = input_positions or []
  301. self.mrope_input_positions = mrope_input_positions or None
  302. self.seq_lens = seq_lens or []
  303. self.orig_seq_lens = orig_seq_lens or []
  304. self.query_lens = query_lens or []
  305. self.context_lens = context_lens or []
  306. self.curr_sliding_window_blocks = \
  307. curr_sliding_window_blocks or []
  308. self.lora_index_mapping = lora_index_mapping or []
  309. self.lora_prompt_mapping = lora_prompt_mapping or []
  310. self.lora_requests = lora_requests or set()
  311. self.prompt_adapter_index_mapping = (
  312. prompt_adapter_index_mapping or [])
  313. self.prompt_adapter_prompt_mapping = (
  314. prompt_adapter_prompt_mapping or [])
  315. self.prompt_adapter_request = prompt_adapter_request
  316. self.multi_modal_inputs = multi_modal_inputs
  317. self.prefix_cache_hit = prefix_cache_hit
  318. self.n_seqs = len(self.seq_ids)
  319. if not reinit:
  320. self.__post_init__()
  321. def __post_init__(self):
  322. self.n_seqs = len(self.seq_ids)
  323. self.input_tokens = [[] for _ in range(self.n_seqs)]
  324. self.input_positions = [[] for _ in range(self.n_seqs)]
  325. self.mrope_input_positions = None
  326. self.seq_lens = [0] * self.n_seqs
  327. self.orig_seq_lens = [0] * self.n_seqs
  328. self.query_lens = [0] * self.n_seqs
  329. self.context_lens = [0] * self.n_seqs
  330. self.curr_sliding_window_blocks = [0] * self.n_seqs
  331. self.lora_index_mapping = []
  332. self.lora_prompt_mapping = []
  333. def gen_inter_data_builder(self, num_seqs: int):
  334. return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup(
  335. request_id="",
  336. seq_ids=[0] * num_seqs,
  337. is_prompt=True,
  338. block_tables=None,
  339. computed_block_nums=[])
  340. def init_cached_inter_data(self, *args, **kwargs):
  341. assert len(args) == 0
  342. assert "seq_ids" in kwargs
  343. seq_ids = kwargs["seq_ids"]
  344. num_seqs = len(seq_ids)
  345. # The inter-data cache is per model_runner
  346. inter_data_cache = self.runner.inter_data_cache
  347. if num_seqs not in inter_data_cache:
  348. inter_data_cache[num_seqs] = PyObjectCache(
  349. self.gen_inter_data_builder(num_seqs))
  350. obj = inter_data_cache[num_seqs].get_object()
  351. obj.__init__(*args, **kwargs)
  352. return obj
  353. def reset_cached_inter_data(self):
  354. for cache in self.runner.inter_data_cache.values():
  355. cache.reset()
  356. def __init__(self,
  357. runner: "GPUModelRunnerBase",
  358. finished_requests_ids: Optional[List[str]] = None):
  359. super().__init__()
  360. # Compute functions for each sequence in a sequence group.
  361. # WARNING: The order of the functions matters!
  362. self.per_seq_compute_fns = [
  363. self._compute_lens,
  364. self._compute_for_prefix_cache_hit,
  365. self._compute_for_sliding_window,
  366. self._compute_lora_input,
  367. ]
  368. # Compute functions for each sequence group.
  369. # WARNING: The order of the functions matters!
  370. self.per_seq_group_compute_fns = [
  371. self._compute_prompt_adapter_input,
  372. self._compute_multi_modal_input,
  373. ]
  374. self.runner = runner
  375. self.model_input_cls = self.runner._model_input_cls
  376. self.attn_backend = self.runner.attn_backend
  377. self.scheduler_config = self.runner.scheduler_config
  378. self.sliding_window = self.runner.sliding_window
  379. self.block_size = self.runner.block_size
  380. self.enable_lora = self.runner.lora_config is not None
  381. self.enable_prompt_adapter = (self.runner.prompt_adapter_config
  382. is not None)
  383. self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
  384. self.finished_requests_ids = finished_requests_ids
  385. self.decode_only = True
  386. # Intermediate data (data in CPU before going to GPU) for
  387. # the current sequence group.
  388. self.inter_data_list: List[
  389. ModelInputForGPUBuilder.InterDataForSeqGroup] = []
  390. # Attention metadata inputs.
  391. self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
  392. weakref.proxy(self))
  393. # Engine/Model configurations.
  394. self.chunked_prefill_enabled = (
  395. self.scheduler_config is not None
  396. and self.scheduler_config.chunked_prefill_enabled)
  397. if self.sliding_window is not None:
  398. self.sliding_window_blocks = (
  399. self.sliding_window + self.block_size - 1) // self.block_size
  400. self.block_aligned_sliding_window = \
  401. self.sliding_window_blocks * self.block_size
  402. def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
  403. seq_group_metadata: SequenceGroupMetadata):
  404. """Compute context length, sequence length and tokens
  405. for the given sequence data.
  406. """
  407. seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]]
  408. token_chunk_size = seq_group_metadata.token_chunk_size
  409. # Compute context length (the number of tokens that are
  410. # already computed) and sequence length (total number of tokens).
  411. seq_len = seq_data.get_len()
  412. if inter_data.is_prompt:
  413. context_len = seq_data.get_num_computed_tokens()
  414. else:
  415. # get_num_computed_tokens is incorrect for spec decoding.
  416. # So, we should have a special logic here.
  417. # TODO: Fix it.
  418. context_len = seq_len - 1
  419. seq_len = min(seq_len, context_len + token_chunk_size)
  420. # Compute tokens.
  421. if inter_data.is_prompt:
  422. tokens = seq_data.get_token_ids()
  423. if context_len != 0 or seq_len < len(tokens):
  424. tokens = tokens[context_len:seq_len]
  425. else:
  426. # Optimization. get_token_ids requires the entire copy of
  427. # tokens.
  428. tokens = seq_data.get_last_token_id()
  429. inter_data.seq_lens[seq_idx] = seq_len
  430. inter_data.orig_seq_lens[seq_idx] = seq_len
  431. inter_data.context_lens[seq_idx] = context_len
  432. if isinstance(tokens, list):
  433. inter_data.input_tokens[seq_idx].extend(tokens)
  434. else:
  435. inter_data.input_tokens[seq_idx].append(tokens)
  436. if (seq_len - context_len) == 1:
  437. inter_data.input_positions[seq_idx].append(seq_len - 1)
  438. else:
  439. inter_data.input_positions[seq_idx].extend(
  440. range(context_len, seq_len))
  441. inter_data.query_lens[
  442. seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
  443. if seq_data.mrope_position_delta is not None:
  444. if inter_data.mrope_input_positions is None:
  445. inter_data.mrope_input_positions = [None] * inter_data.n_seqs
  446. inter_data.mrope_input_positions[
  447. seq_idx] = MRotaryEmbedding.get_next_input_positions(
  448. seq_data.mrope_position_delta,
  449. context_len,
  450. seq_len,
  451. )
  452. def _compute_for_prefix_cache_hit(
  453. self, inter_data: InterDataForSeqGroup, seq_idx: int,
  454. seq_group_metadata: SequenceGroupMetadata):
  455. """Check if hit prefix cache (i.e., some blocks are already computed).
  456. If hit, update input tokens and positions to only compute the
  457. remaining blocks.
  458. """
  459. computed_block_nums = inter_data.computed_block_nums
  460. # Note that prefix caching does not support sliding window.
  461. prefix_cache_hit = (computed_block_nums is not None
  462. and len(computed_block_nums) > 0
  463. and self.sliding_window is None
  464. and inter_data.is_prompt)
  465. inter_data.prefix_cache_hit = prefix_cache_hit
  466. if not prefix_cache_hit:
  467. return
  468. assert computed_block_nums is not None
  469. # The cache hit prompt tokens in this sequence. Note that
  470. # this may be larger than the sequence length if chunked
  471. # prefill is enabled.
  472. prefix_cache_len = len(computed_block_nums) * self.block_size
  473. # The number of so far computed prompt tokens in this sequence.
  474. context_len = inter_data.context_lens[seq_idx]
  475. # The total number of prompt tokens in this sequence.
  476. # When chunked prefill is enabled, this is the token number of
  477. # computed chunks + current chunk.
  478. seq_len = inter_data.seq_lens[seq_idx]
  479. if prefix_cache_len <= context_len:
  480. # We already passed the cache hit region,
  481. # so do normal computation.
  482. pass
  483. elif context_len < prefix_cache_len < seq_len:
  484. # Partial hit. Compute the missing part.
  485. uncomputed_start = prefix_cache_len - context_len
  486. inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
  487. seq_idx][uncomputed_start:]
  488. inter_data.input_positions[seq_idx] = inter_data.input_positions[
  489. seq_idx][uncomputed_start:]
  490. context_len = prefix_cache_len
  491. inter_data.context_lens[seq_idx] = context_len
  492. inter_data.query_lens[
  493. seq_idx] = inter_data.seq_lens[seq_idx] - context_len
  494. elif seq_len <= prefix_cache_len:
  495. # Full hit. Only compute the last token to avoid
  496. # erroneous behavior. FIXME: Ideally we should directly
  497. # mark all tokens as computed in the scheduler and do not
  498. # schedule this sequence, so this case should not happen.
  499. inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
  500. seq_idx][-1:]
  501. inter_data.input_positions[seq_idx] = inter_data.input_positions[
  502. seq_idx][-1:]
  503. inter_data.query_lens[seq_idx] = 1
  504. inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
  505. def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
  506. seq_idx: int,
  507. seq_group_metadata: SequenceGroupMetadata):
  508. """Update seq_len and curr_sliding_window_block for the given
  509. sequence data (only required by decoding) if sliding window is enabled.
  510. """
  511. curr_sliding_window_block = 0
  512. sliding_seq_len = inter_data.seq_lens[seq_idx]
  513. if not inter_data.is_prompt and self.sliding_window is not None:
  514. # TODO: This is a hack to make sliding window work with
  515. # paged attn. We can remove it if we make paged attn kernel
  516. # to properly handle slinding window attn.
  517. curr_sliding_window_block = self.sliding_window_blocks
  518. if self.scheduler_config.use_v2_block_manager:
  519. # number of elements in last block
  520. suff_len = inter_data.seq_lens[seq_idx] % self.block_size
  521. sliding_seq_len = min(
  522. inter_data.seq_lens[seq_idx],
  523. self.block_aligned_sliding_window + suff_len)
  524. if suff_len > 0:
  525. curr_sliding_window_block += 1
  526. else:
  527. sliding_seq_len = min(inter_data.seq_lens[seq_idx],
  528. self.sliding_window)
  529. inter_data.curr_sliding_window_blocks[
  530. seq_idx] = curr_sliding_window_block
  531. inter_data.seq_lens[seq_idx] = sliding_seq_len
  532. def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
  533. seq_idx: int,
  534. seq_group_metadata: SequenceGroupMetadata):
  535. """If LoRA is enabled, compute LoRA index and prompt mapping."""
  536. if not self.enable_lora:
  537. return
  538. lora_id = seq_group_metadata.lora_int_id
  539. if lora_id > 0:
  540. inter_data.lora_requests.add(seq_group_metadata.lora_request)
  541. query_len = inter_data.query_lens[seq_idx]
  542. inter_data.lora_index_mapping.append([lora_id] * query_len)
  543. sampling_params = seq_group_metadata.sampling_params
  544. if sampling_params and sampling_params.prompt_logprobs is not None:
  545. inter_data.lora_prompt_mapping.append([lora_id] * query_len)
  546. elif not self.chunked_prefill_enabled or seq_group_metadata.do_sample:
  547. inter_data.lora_prompt_mapping.append([lora_id])
  548. else:
  549. inter_data.lora_prompt_mapping.append([])
  def _compute_prompt_adapter_input(
          self, inter_data: InterDataForSeqGroup,
          seq_group_metadata: SequenceGroupMetadata):
      """If prompt adapter is enabled, compute index and prompt mapping.

      Only prompt (prefill) groups with a positive ``prompt_adapter_id`` are
      handled; every other case is a no-op. The first
      ``prompt_adapter_num_virtual_tokens`` query positions map to the
      adapter id and the remaining query positions map to 0.

      Args:
          inter_data: Intermediate data of the sequence group being built.
          seq_group_metadata: Metadata of the sequence group.
      """
      if not self.enable_prompt_adapter:
          return

      prompt_adapter_id = seq_group_metadata.prompt_adapter_id
      if prompt_adapter_id <= 0 or not inter_data.is_prompt:
          return

      # We expect only one sequence in the group when is_prompt=True.
      assert inter_data.n_seqs == 1
      query_len = inter_data.query_lens[0]
      inter_data.prompt_adapter_request = (
          seq_group_metadata.prompt_adapter_request)

      num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens
      inter_data.prompt_adapter_index_mapping = (
          [prompt_adapter_id] * num_tokens + [0] * (query_len - num_tokens))

      # Use the same criterion as `_compute_lora_input`
      # (`prompt_logprobs is not None`). A value of 0 still counts as a
      # prompt-logprobs request there; plain truthiness would wrongly
      # collapse this mapping to a single entry.
      sampling_params = seq_group_metadata.sampling_params
      if sampling_params and sampling_params.prompt_logprobs is not None:
          num_prompt_entries = query_len
      else:
          num_prompt_entries = 1
      inter_data.prompt_adapter_prompt_mapping = (
          [prompt_adapter_id] * num_prompt_entries)
  def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
                                 seq_group_metadata: SequenceGroupMetadata):
      """Attach mapped multi-modal kwargs to ``inter_data``.

      Returns early when the group carries no multi-modal data. For models
      whose rope type is "mrope", it also computes the mrope input positions
      of every sequence in the group. Each sequence's position delta is
      stored on its sequence data.

      Args:
          inter_data: Intermediate data of the sequence group being built.
          seq_group_metadata: Metadata of the sequence group.
      """
      multi_modal_data = seq_group_metadata.multi_modal_data
      if not multi_modal_data:
          return

      mapped = self.multi_modal_input_mapper(multi_modal_data)
      inter_data.multi_modal_inputs = mapped

      if not self.runner.model_is_mrope:
          return

      # get_input_positions needs the image and/or video grid sizes, so the
      # mapper must have produced at least one of them.
      image_grid_thw = mapped.get("image_grid_thw", None)
      video_grid_thw = mapped.get("video_grid_thw", None)
      assert image_grid_thw is not None or video_grid_thw is not None, (
          "mrope embedding type requires multi-modal input mapper "
          "returns 'image_grid_thw' or 'video_grid_thw'.")

      hf_config = self.runner.model_config.hf_config
      positions_per_seq = [None] * inter_data.n_seqs
      inter_data.mrope_input_positions = positions_per_seq
      for idx in range(inter_data.n_seqs):
          seq_id = inter_data.seq_ids[idx]
          seq_data = seq_group_metadata.seq_data[seq_id]
          positions, position_delta = MRotaryEmbedding.get_input_positions(
              seq_data.get_token_ids(),
              image_grid_thw=image_grid_thw,
              video_grid_thw=video_grid_thw,
              image_token_id=hf_config.image_token_id,
              video_token_id=hf_config.video_token_id,
              vision_start_token_id=hf_config.vision_start_token_id,
              vision_end_token_id=hf_config.vision_end_token_id,
              spatial_merge_size=hf_config.vision_config.spatial_merge_size,
              context_len=inter_data.context_lens[idx],
          )
          seq_data.mrope_position_delta = position_delta
          positions_per_seq[idx] = positions
  def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
      """Register one sequence group and run every compute hook on it.

      A prompt group must contain exactly one sequence, and it marks the
      batch as no longer decode-only. Per-sequence hooks run for every
      sequence first; per-group hooks run afterwards.

      Args:
          seq_group_metadata: Metadata of the sequence group to add.
      """
      seq_ids = seq_group_metadata.seq_data.keys()
      num_seqs = len(seq_ids)
      is_prompt = seq_group_metadata.is_prompt
      if is_prompt:
          assert num_seqs == 1
          self.decode_only = False

      inter_data = self.init_cached_inter_data(
          request_id=seq_group_metadata.request_id,
          seq_ids=seq_ids,
          is_prompt=is_prompt,
          block_tables=seq_group_metadata.block_tables,
          computed_block_nums=seq_group_metadata.computed_block_nums,
          reinit=True,
          reinit_use_defaults=True,
      )
      self.inter_data_list.append(inter_data)

      for seq_idx in range(num_seqs):
          for compute_fn in self.per_seq_compute_fns:
              compute_fn(inter_data, seq_idx, seq_group_metadata)

      for group_fn in self.per_seq_group_compute_fns:
          group_fn(inter_data, seq_group_metadata)
  633. def _use_captured_graph(self, batch_size: int,
  634. max_decode_seq_len: int) -> bool:
  635. return (self.decode_only and not self.runner.model_config.enforce_eager
  636. and batch_size <= self.runner.max_batchsize_to_capture
  637. and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
  def build(self) -> ModelInputForGPU:
      """Finalize the builder intermediate data and create on-device tensors.

      Steps:
      - flatten the per-group data gathered by ``add_seq_group`` into
        batch-level lists;
      - pad them when the batch will be replayed through a captured CUDA
        graph;
      - copy tokens and positions to the runner's device;
      - assemble the attention, LoRA, prompt-adapter and multi-modal inputs.

      Returns:
          An instance of ``self.model_input_cls``. It is default-constructed
          (empty) when the batch has no input tokens at all.
      """
      # Combine and flatten intermediate data.
      input_tokens: List[int] = []
      for inter_data in self.inter_data_list:
          for cur_input_tokens in inter_data.input_tokens:
              input_tokens.extend(cur_input_tokens)

      if not input_tokens:
          # This may happen when all prefill requests hit
          # prefix caching and there is no decode request.
          return self.model_input_cls()

      # Positions. If any group carries mrope positions, build three rows;
      # groups without them contribute their plain positions to every row.
      mrope_input_positions: Optional[List[List[int]]] = None
      input_positions: Optional[List[int]] = None
      if any(inter_data.mrope_input_positions is not None
             for inter_data in self.inter_data_list):
          mrope_input_positions = [[] for _ in range(3)]
          for idx in range(3):
              for inter_data in self.inter_data_list:
                  msections = inter_data.mrope_input_positions
                  if msections is None:
                      for _seq_input_positions in inter_data.input_positions:
                          mrope_input_positions[idx].extend(
                              _seq_input_positions)
                  else:
                      for _seq_mrope_input_positions in msections:
                          mrope_input_positions[idx].extend(
                              _seq_mrope_input_positions[idx])
      else:
          input_positions = []
          for inter_data in self.inter_data_list:
              for cur_input_positions in inter_data.input_positions:
                  input_positions.extend(cur_input_positions)

      # Sequence/query lengths, plus the longest decode sequence, which
      # decides whether a captured CUDA graph can be used.
      seq_lens: List[int] = []
      query_lens: List[int] = []
      max_decode_seq_len = 0
      for inter_data in self.inter_data_list:
          seq_lens.extend(inter_data.seq_lens)
          query_lens.extend(inter_data.query_lens)
          if not inter_data.is_prompt:
              max_decode_seq_len = max(max_decode_seq_len,
                                       max(inter_data.seq_lens))

      # Mapping from request IDs to sequence IDs. Used for Jamba models
      # that manage the cache by themselves.
      request_ids_to_seq_ids = {
          data.request_id: data.seq_ids
          for data in self.inter_data_list
      }

      batch_size = len(input_tokens)
      use_captured_graph = self._use_captured_graph(batch_size,
                                                    max_decode_seq_len)

      # If cuda graph can be used, pad tensors accordingly.
      # See `capture_model` API for more details.
      # Aphrodite uses cuda graph only for decoding requests.
      # -1 marks "no CUDA graph". It is forwarded unchanged to the attention
      # metadata builder, but must never be used as a padding count (it is
      # truthy, so `if cuda_graph_pad_size:` does not guard against it).
      cuda_graph_pad_size = -1
      if use_captured_graph:
          graph_batch_size = _get_graph_batch_size(batch_size)
          assert graph_batch_size >= batch_size
          cuda_graph_pad_size = graph_batch_size - batch_size
          batch_size = graph_batch_size
      # Number of dummy entries appended to every per-token list (0 when no
      # CUDA graph is used).
      num_pad = max(cuda_graph_pad_size, 0)

      # Tokens and positions.
      input_tokens.extend(itertools.repeat(0, num_pad))
      assert self.runner.device is not None
      input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
                                             self.runner.device,
                                             self.runner.pin_memory)
      if mrope_input_positions is not None:
          for positions_row in mrope_input_positions:
              positions_row.extend(itertools.repeat(0, num_pad))
          input_positions_tensor = async_tensor_h2d(mrope_input_positions,
                                                    torch.long,
                                                    self.runner.device,
                                                    self.runner.pin_memory)
      else:
          input_positions.extend(itertools.repeat(0, num_pad))
          input_positions_tensor = async_tensor_h2d(input_positions,
                                                    torch.long,
                                                    self.runner.device,
                                                    self.runner.pin_memory)

      # Sequence and query lengths (padded sequences get length 1).
      seq_lens.extend(itertools.repeat(1, num_pad))

      # Attention metadata.
      attn_metadata = self.attn_metadata_builder.build(
          seq_lens, query_lens, cuda_graph_pad_size, batch_size)

      # LoRA data.
      lora_requests = set()
      lora_mapping = None
      if self.enable_lora:
          lora_requests = {
              r for data in self.inter_data_list for r in data.lora_requests
          }
          lora_index_mapping = flatten_2d_lists([
              flatten_2d_lists(inter_data.lora_index_mapping)
              for inter_data in self.inter_data_list
          ])
          lora_index_mapping.extend(itertools.repeat(0, num_pad))
          lora_prompt_mapping = flatten_2d_lists([
              flatten_2d_lists(inter_data.lora_prompt_mapping)
              for inter_data in self.inter_data_list
          ])
          lora_mapping = LoRAMapping(
              index_mapping=lora_index_mapping,
              prompt_mapping=lora_prompt_mapping,
              is_prefill=not self.decode_only)

      # Prompt adapter data.
      prompt_adapter_requests: Set[PromptAdapterRequest] = set()
      prompt_adapter_mapping = None
      if self.enable_prompt_adapter:
          prompt_adapter_requests = {
              data.prompt_adapter_request for data in self.inter_data_list
              if data.prompt_adapter_request is not None
          }
          prompt_adapter_index_mapping = flatten_2d_lists([
              inter_data.prompt_adapter_index_mapping
              for inter_data in self.inter_data_list
          ])
          prompt_adapter_index_mapping.extend(itertools.repeat(0, num_pad))
          prompt_adapter_prompt_mapping = flatten_2d_lists([
              inter_data.prompt_adapter_prompt_mapping
              for inter_data in self.inter_data_list
          ])
          prompt_adapter_mapping = PromptAdapterMapping(
              prompt_adapter_index_mapping,
              prompt_adapter_prompt_mapping,
          )

      # Multi-modal data.
      multi_modal_inputs_list = [
          data.multi_modal_inputs for data in self.inter_data_list
          if data.multi_modal_inputs is not None
      ]
      multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

      return self.model_input_cls(
          input_tokens=input_tokens_tensor,
          input_positions=input_positions_tensor,
          attn_metadata=attn_metadata,
          seq_lens=seq_lens,
          query_lens=query_lens,
          lora_mapping=lora_mapping,
          lora_requests=lora_requests,
          multi_modal_kwargs=multi_modal_kwargs,
          request_ids_to_seq_ids=request_ids_to_seq_ids,
          finished_requests_ids=self.finished_requests_ids,
          prompt_adapter_mapping=prompt_adapter_mapping,
          prompt_adapter_requests=prompt_adapter_requests)
  class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
      """
      Helper class for shared methods between GPU model runners.

      Owns the loaded model, the attention backend/state, the optional LoRA
      and prompt-adapter managers, and the captured CUDA graphs. Graphs are
      stored as one ``{batch_size: CUDAGraphRunner}`` dict per pipeline
      virtual engine.
      """
      _model_input_cls: Type[TModelInputForGPU]
      _builder_cls: Type[ModelInputForGPUBuilder]

      def __init__(
          self,
          model_config: ModelConfig,
          parallel_config: ParallelConfig,
          scheduler_config: SchedulerConfig,
          device_config: DeviceConfig,
          cache_config: CacheConfig,
          load_config: LoadConfig,
          lora_config: Optional[LoRAConfig],
          kv_cache_dtype: Optional[str] = "auto",
          is_driver_worker: bool = False,
          prompt_adapter_config: Optional[PromptAdapterConfig] = None,
          return_hidden_states: bool = False,
          input_registry: InputRegistry = INPUT_REGISTRY,
          mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
          tp_rank: int = 0,
      ):
          self.model_config = model_config
          self.parallel_config = parallel_config
          self.scheduler_config = scheduler_config
          self.device_config = device_config
          self.cache_config = cache_config
          self.lora_config = lora_config
          self.load_config = load_config
          self.is_driver_worker = is_driver_worker
          self.prompt_adapter_config = prompt_adapter_config
          self.return_hidden_states = return_hidden_states
          self.device = self.device_config.device
          self.pin_memory = is_pin_memory_available()
          self.tp_rank = tp_rank
          self.kv_cache_dtype = kv_cache_dtype
          self.sliding_window = model_config.get_sliding_window()
          self.block_size = cache_config.block_size
          self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
          self.max_batchsize_to_capture = _get_max_graph_batch_size(
              self.scheduler_config.max_num_seqs)

          # One {batch_size: CUDAGraphRunner} dict per pipeline virtual engine.
          self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
              {} for _ in range(self.parallel_config.pipeline_parallel_size)
          ]
          self.graph_memory_pool: Optional[Tuple[
              int, int]] = None  # Set during graph capture.

          self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers(
              parallel_config)

          # When using CUDA graph, the input block tables must be padded to
          # max_seq_len_to_capture. However, creating the block table in
          # Python can be expensive. To optimize this, we cache the block table
          # in numpy and only copy the actual input content at every iteration.
          # The shape of the cached block table will be
          # (max batch size to capture, max context len to capture / block size).
          self.graph_block_tables = np.zeros(
              (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
              dtype=np.int32)

          self.attn_backend = get_attn_backend(
              self.model_config.get_head_size(),
              self.model_config.get_sliding_window(),
              self.model_config.dtype,
              self.kv_cache_dtype,
              self.block_size,
              self.model_config.is_attention_free(),
          )
          if self.attn_backend:
              self.attn_state = self.attn_backend.get_state_cls()(
                  weakref.proxy(self))
          else:
              self.attn_state = CommonAttentionState(weakref.proxy(self))

          # Multi-modal data support
          self.input_registry = input_registry
          self.mm_registry = mm_registry
          self.multi_modal_input_mapper = mm_registry \
              .create_input_mapper(model_config)
          self.mm_registry.init_mm_limits_per_prompt(self.model_config)

          # Lazy initialization
          self.model: nn.Module  # Set after load_model
          # Set after load_model.
          self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
          self.prompt_adapter_manager: Optional[
              LRUCacheWorkerPromptAdapterManager] = None

          set_cpu_offload_max_bytes(
              int(self.cache_config.cpu_offload_gb * 1024**3))

          # Used to cache python objects
          self.inter_data_cache: Dict[int, PyObjectCache] = {}
          self.sampling_metadata_cache: SamplingMetadataCache = \
              SamplingMetadataCache()

      def load_model(self) -> None:
          """Load the model weights and apply the configured wrappers.

          Records the memory consumed while loading in
          ``self.model_memory_usage``. Then, in order:
          - wrap the model with the LoRA manager, if configured;
          - wrap it with the prompt-adapter manager, if configured;
          - on ROCm with an FP8 KV cache, load the KV-cache scaling factors;
          - compile it with ``torch.compile``, behind a test-only flag.
          """
          tp = get_tensor_model_parallel_world_size()
          rank = get_tensor_model_parallel_rank()
          if rank == 0:
              logger.info(f"Loading model {self.model_config.model}...")
          with CudaMemoryProfiler() as m:
              # measure the time it takes to load the model
              start_time = time.time()
              self.model = get_model(model_config=self.model_config,
                                     device_config=self.device_config,
                                     load_config=self.load_config,
                                     lora_config=self.lora_config,
                                     parallel_config=self.parallel_config,
                                     scheduler_config=self.scheduler_config,
                                     cache_config=self.cache_config)
              end_time = time.time()
          self.model_memory_usage = m.consumed_memory
          total_time = end_time - start_time
          if tp > 1:
              if rank == 0:
                  logger.info(f"Model loaded in {total_time:.2f} seconds.")
                  logger.info(
                      "Total model weights memory usage: "
                      f"{self.model_memory_usage * tp / float(2**30):.2f} GiB")
          else:
              logger.info(f"Model weights loaded in {total_time:.2f} seconds.")
              logger.info(
                  "Total model weights memory usage: "
                  f"{self.model_memory_usage / float(2**30):.2f} GiB")

          if self.lora_config:
              assert supports_lora(self.model), "Model does not support LoRA"
              assert not supports_multimodal(
                  self.model
              ), "To be tested: multimodal language model with LoRA settings."
              self.lora_manager = LRUCacheWorkerLoRAManager(
                  self.scheduler_config.max_num_seqs,
                  self.scheduler_config.max_num_batched_tokens,
                  self.vocab_size,
                  self.lora_config,
                  self.device,
                  self.model.embedding_modules,
                  self.model.embedding_padding_modules,
                  max_position_embeddings=self.model.config.
                  max_position_embeddings,
              )
              self.model = self.lora_manager.create_lora_manager(self.model)

          if self.prompt_adapter_config:
              self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
                  self.scheduler_config.max_num_seqs,
                  self.scheduler_config.max_num_batched_tokens, self.device,
                  self.prompt_adapter_config)
              self.model = (
                  self.prompt_adapter_manager.create_prompt_adapter_manager(
                      self.model))

          if self.kv_cache_dtype == "fp8" and is_hip():
              # Currently only ROCm accepts kv-cache scaling factors
              # via quantization_param_path and this will be deprecated
              # in the future.
              if self.model_config.quantization_param_path is not None:
                  if callable(getattr(self.model, "load_kv_cache_scales", None)):
                      warnings.warn(
                          "Loading kv cache scaling factor from JSON is "
                          "deprecated and will be removed. Please include "
                          "kv cache scaling factors in the model checkpoint.",
                          FutureWarning,
                          stacklevel=2)
                      self.model.load_kv_cache_scales(
                          self.model_config.quantization_param_path)
                      # Build the whole message up front: a second positional
                      # argument is not interpolated into a message that has
                      # no placeholder, so the path would never be shown.
                      logger.info(
                          "Loaded KV cache scaling factors from "
                          f"{self.model_config.quantization_param_path}")
                  else:
                      raise RuntimeError(
                          "Using FP8 KV cache and scaling factors provided but "
                          f"model {self.model.__class__} does not support loading"
                          " scaling factors.", )
              else:
                  logger.warning(
                      "Using FP8 KV cache but no scaling factors "
                      "provided. Defaulting to scaling factors of 1.0. "
                      "This may lead to less accurate results!")

          # `supports_dynamo` must be *called*: the bare function object is
          # always truthy, which would bypass the torch-version check.
          if envs.APHRODITE_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
              logger.info("Compiling the model using torch.compile...")
              start_time = time.time()
              self.model = torch.compile(
                  self.model,
                  fullgraph=envs.APHRODITE_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                  backend="eager")
              end_time = time.time()
              logger.info(
                  f"Model compiled in {end_time - start_time:.2f} seconds.")

      def get_model_memory_usage(self):
          """Return the memory consumed while loading the model weights."""
          return self.model_memory_usage

      def save_sharded_state(
          self,
          path: str,
          pattern: Optional[str] = None,
          max_size: Optional[int] = None,
      ) -> None:
          """Save the model with ``ShardedStateLoader.save_model``."""
          from aphrodite.modeling.model_loader.loader import ShardedStateLoader
          ShardedStateLoader.save_model(
              self.model,
              path,
              pattern=pattern,
              max_size=max_size,
          )

      def save_tensorized_model(
          self,
          tensorizer_config: TensorizerConfig,
      ) -> None:
          """Save the model with ``TensorizerLoader.save_model``."""
          from aphrodite.modeling.model_loader.loader import TensorizerLoader
          TensorizerLoader.save_model(
              self.model,
              tensorizer_config=tensorizer_config,
          )

      def get_max_block_per_batch(self) -> int:
          """Blocks needed to hold ``max_seq_len_to_capture`` tokens (ceil)."""
          block_size = self.block_size
          return (self.max_seq_len_to_capture + block_size - 1) // block_size

      def _prepare_model_input_tensors(
          self,
          seq_group_metadata_list: List[SequenceGroupMetadata],
          finished_requests_ids: Optional[List[str]] = None
      ) -> TModelInputForGPU:
          """Helper method to prepare the model input based on a given sequence
          group. Prepares metadata needed for the base model forward pass but not
          metadata for possible additional steps, e.g., sampling.

          The API assumes seq_group_metadata_list is sorted by prefill -> decode.

          The result tensors and data structure also batches input in prefill
          -> decode order. For example,

          - input_tokens[:num_prefill_tokens] contains prefill tokens.
          - input_tokens[num_prefill_tokens:] contains decode tokens.

          If cuda graph is required, this API automatically pads inputs.
          """
          builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
          for seq_group_metadata in seq_group_metadata_list:
              builder.add_seq_group(seq_group_metadata)

          builder.reset_cached_inter_data()

          return builder.build()  # type: ignore

      @torch.inference_mode()
      def profile_run(self) -> None:
          """Run one forward pass on dummy prompts to profile peak memory.

          In single-user mode, a single sequence of ``max_num_batched_tokens``
          tokens is used. Otherwise ``max_num_batched_tokens`` is split across
          up to ``max_num_seqs`` sequences; that count may be reduced so that
          each sequence can hold the maximum number of multi-modal tokens.
          """
          rank = get_tensor_model_parallel_rank()
          if rank == 0:
              logger.info("Profiling peak memory usage...")
          # Enable top-k sampling to reflect the accurate memory usage.
          sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
          max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
          max_num_seqs = self.scheduler_config.max_num_seqs
          single_seq_mode = self.scheduler_config.single_user_mode
          if single_seq_mode and rank == 0:
              logger.info("Running in single sequence profiling mode")

          # This represents the maximum number of different requests
          # that will have unique loras, and therefore the max amount of memory
          # consumption. Create dummy lora request copies from the lora request
          # passed in, which contains a lora from the lora warmup path.
          dummy_lora_requests: List[LoRARequest] = []
          dummy_lora_requests_per_seq: List[LoRARequest] = []
          if self.lora_config:
              assert self.lora_manager is not None
              with self.lora_manager.dummy_lora_cache():
                  for idx in range(self.lora_config.max_loras):
                      lora_id = idx + 1
                      dummy_lora_request = LoRARequest(
                          lora_name=f"warmup_{lora_id}",
                          lora_int_id=lora_id,
                          lora_local_path="/not/a/real/path",
                      )
                      self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                       rank=LORA_WARMUP_RANK)
                      dummy_lora_requests.append(dummy_lora_request)
                  if single_seq_mode:
                      dummy_lora_requests_per_seq = [dummy_lora_requests[0]]
                  else:
                      dummy_lora_requests_per_seq = [
                          dummy_lora_requests[idx % len(dummy_lora_requests)]
                          for idx in range(max_num_seqs)
                      ]

          # Profile memory usage with max_num_sequences sequences and the total
          # number of tokens equal to max_num_batched_tokens.
          seqs: List[SequenceGroupMetadata] = []
          # Additional GPU memory may be needed for multi-modal encoding, which
          # needs to be accounted for when calculating the GPU blocks for
          # Aphrodite block manager.
          # To exercise the worst scenario for GPU memory consumption,
          # the number of seqs (batch_size) is chosen to maximize the number
          # of images processed.
          max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
              self.model_config)
          if max_mm_tokens > 0 and not single_seq_mode:
              max_num_seqs_orig = max_num_seqs
              max_num_seqs = min(max_num_seqs,
                                 max_num_batched_tokens // max_mm_tokens)
              if max_num_seqs < 1:
                  expr = (f"min({max_num_seqs_orig}, "
                          f"{max_num_batched_tokens} // {max_mm_tokens})")
                  logger.warning(
                      f"Computed max_num_seqs ({expr}) to be less than 1. "
                      "Setting it to the minimum value of 1.")
                  max_num_seqs = 1

          batch_size = 0
          if single_seq_mode:
              seq_len = max_num_batched_tokens
              batch_size = seq_len
              seq_data, dummy_multi_modal_data = self.input_registry \
                  .dummy_data_for_profiling(self.model_config,
                                            seq_len,
                                            self.mm_registry)
              seq = SequenceGroupMetadata(
                  request_id="0",
                  is_prompt=True,
                  seq_data={0: seq_data},
                  sampling_params=sampling_params,
                  block_tables=None,
                  lora_request=dummy_lora_requests_per_seq[0]
                  if dummy_lora_requests_per_seq else None,
                  multi_modal_data=dummy_multi_modal_data,
              )
              seqs.append(seq)
          else:
              # Original multi-sequence profiling logic
              for group_id in range(max_num_seqs):
                  seq_len = (max_num_batched_tokens // max_num_seqs +
                             (group_id < max_num_batched_tokens % max_num_seqs))
                  batch_size += seq_len
                  seq_data, dummy_multi_modal_data = self.input_registry \
                      .dummy_data_for_profiling(self.model_config,
                                                seq_len,
                                                self.mm_registry)
                  seq = SequenceGroupMetadata(
                      request_id=str(group_id),
                      is_prompt=True,
                      seq_data={group_id: seq_data},
                      sampling_params=sampling_params,
                      block_tables=None,
                      lora_request=dummy_lora_requests_per_seq[group_id]
                      if dummy_lora_requests_per_seq else None,
                      multi_modal_data=dummy_multi_modal_data,
                  )
                  seqs.append(seq)

          # Run the model with the dummy inputs.
          num_layers = self.model_config.get_num_layers(self.parallel_config)
          kv_caches = [None] * num_layers
          finished_requests_ids = [seq.request_id for seq in seqs]
          model_input = self.prepare_model_input(
              seqs, finished_requests_ids=finished_requests_ids)
          intermediate_tensors = None
          if not get_pp_group().is_first_rank:
              intermediate_tensors = self.model.make_empty_intermediate_tensors(
                  batch_size=batch_size,
                  dtype=self.model_config.dtype,
                  device=self.device)
          self.execute_model(model_input, kv_caches, intermediate_tensors)
          torch.cuda.synchronize()
          return

      def remove_all_loras(self):
          if not self.lora_manager:
              raise RuntimeError("LoRA is not enabled.")
          self.lora_manager.remove_all_adapters()

      def set_active_loras(self, lora_requests: Set[LoRARequest],
                           lora_mapping: LoRAMapping) -> None:
          if not self.lora_manager:
              raise RuntimeError("LoRA is not enabled.")
          self.lora_manager.set_active_adapters(lora_requests, lora_mapping)

      def add_lora(self, lora_request: LoRARequest) -> bool:
          if not self.lora_manager:
              raise RuntimeError("LoRA is not enabled.")
          return self.lora_manager.add_adapter(lora_request)

      def remove_lora(self, lora_id: int) -> bool:
          if not self.lora_manager:
              raise RuntimeError("LoRA is not enabled.")
          return self.lora_manager.remove_adapter(lora_id)

      def pin_lora(self, lora_id: int) -> bool:
          if not self.lora_manager:
              raise RuntimeError("LoRA is not enabled.")
          return self.lora_manager.pin_adapter(lora_id)

      def list_loras(self) -> Set[int]:
          if not self.lora_manager:
              raise RuntimeError("LoRA is not enabled.")
          return self.lora_manager.list_adapters()

      def remove_all_prompt_adapters(self):
          if not self.prompt_adapter_manager:
              raise RuntimeError("PromptAdapter is not enabled.")
          self.prompt_adapter_manager.remove_all_adapters()

      def set_active_prompt_adapters(
              self, prompt_adapter_requests: Set[PromptAdapterRequest],
              prompt_adapter_mapping: PromptAdapterMapping) -> None:
          if not self.prompt_adapter_manager:
              raise RuntimeError("PromptAdapter is not enabled.")
          self.prompt_adapter_manager.set_active_adapters(
              prompt_adapter_requests, prompt_adapter_mapping)

      def add_prompt_adapter(
              self, prompt_adapter_request: PromptAdapterRequest) -> bool:
          if not self.prompt_adapter_manager:
              raise RuntimeError("PromptAdapter is not enabled.")
          return self.prompt_adapter_manager.add_adapter(prompt_adapter_request)

      def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
          if not self.prompt_adapter_manager:
              raise RuntimeError("PromptAdapter is not enabled.")
          return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id)

      def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
          if not self.prompt_adapter_manager:
              raise RuntimeError("PromptAdapter is not enabled.")
          return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id)

      def list_prompt_adapters(self) -> Set[int]:
          if not self.prompt_adapter_manager:
              raise RuntimeError("PromptAdapter is not enabled.")
          return self.prompt_adapter_manager.list_adapters()

      @property
      def model_is_mrope(self) -> bool:
          """Detect if the model has "mrope" rope_scaling type.

          mrope requires keep "rope_deltas" between prompt and decoding phases.
          """
          rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
          if rope_scaling is None:
              return False
          return rope_scaling.get("type", None) == "mrope"

      @torch.inference_mode()
      def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
          """Cuda graph capture a model.

          Note that CUDA graph's performance gain is negligible if number
          of batched tokens are larger than 200. And since CUDA graph
          requires fixed sized tensors, supporting large/variable batch
          size requires high GPU memory overhead. Thus, Aphrodite only captures
          decoding requests. Mixed batch (chunked prefill + decoding) or
          prefill requests are not captured.

          Since it is used for decoding-only, it assumes there's only 1 token
          per sequence in the batch.
          """
          tp_rank = get_tensor_model_parallel_rank()
          assert not self.model_config.enforce_eager
          if tp_rank == 0:
              logger.info(
                  "Capturing the model for CUDA graphs. This may lead to "
                  "unexpected consequences if the model is not static. To "
                  "run the model in eager mode, set 'enforce_eager=True' or "
                  "use '--enforce-eager' in the CLI.")
              logger.info(
                  "CUDA graphs can take additional 1~3 GiB memory per GPU. "
                  "If you are running out of memory, consider decreasing "
                  "`gpu_memory_utilization` or enforcing eager mode. "
                  "You can also reduce the `max_num_seqs` as needed "
                  "to decrease memory usage.")
          start_time = time.perf_counter()

          # Prepare dummy inputs. These will be reused for all batch sizes.
          max_batch_size = self.max_batchsize_to_capture
          input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
          input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
          if self.model_is_mrope:
              input_positions = torch.tile(input_positions, (3, 1))
          # Prepare dummy previous_hidden_states only if needed by the model.
          # This is used by draft models such as EAGLE.
          previous_hidden_states = None
          if "previous_hidden_states" in inspect.signature(
                  self.model.forward).parameters:
              previous_hidden_states = torch.empty(
                  [max_batch_size,
                   self.model_config.get_hidden_size()],
                  dtype=self.model_config.dtype,
                  device=self.device)

          intermediate_inputs = None
          if not get_pp_group().is_first_rank:
              intermediate_inputs = self.model.make_empty_intermediate_tensors(
                  batch_size=max_batch_size,
                  dtype=self.model_config.dtype,
                  device=self.device)

          # Prepare buffer for outputs. These will be reused for all batch sizes.
          # It will be filled after the first graph capture.
          hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [
              None
          ] * self.parallel_config.pipeline_parallel_size

          graph_batch_size = self.max_batchsize_to_capture
          batch_size_capture_list = [
              bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
          ]

          with self.attn_state.graph_capture(
                  max_batch_size), graph_capture() as graph_capture_context:
              # NOTE: Capturing the largest batch size first may help reduce the
              # memory usage of CUDA graph.
              for virtual_engine in range(
                      self.parallel_config.pipeline_parallel_size):
                  for batch_size in reversed(batch_size_capture_list):
                      attn_metadata = (
                          self.attn_state.graph_capture_get_metadata_for_batch(
                              batch_size))

                      if self.lora_config:
                          lora_mapping = LoRAMapping(
                              **dict(index_mapping=[0] * batch_size,
                                     prompt_mapping=[0] * batch_size,
                                     is_prefill=False))
                          self.set_active_loras(set(), lora_mapping)

                      if self.prompt_adapter_config:
                          prompt_adapter_mapping = PromptAdapterMapping(
                              [-1] * batch_size,
                              [-1] * batch_size,
                          )
                          self.set_active_prompt_adapters(
                              set(), prompt_adapter_mapping)

                      graph_runner = CUDAGraphRunner(
                          self.model, self.attn_backend.get_name(),
                          self.attn_state.graph_clone(batch_size))

                      capture_inputs = {
                          "input_ids":
                          input_tokens[:batch_size],
                          "positions":
                          input_positions[..., :batch_size],
                          "hidden_or_intermediate_states":
                          hidden_or_intermediate_states[
                              virtual_engine]  # type: ignore
                          [:batch_size]
                          if hidden_or_intermediate_states[virtual_engine]
                          is not None else None,
                          "intermediate_inputs":
                          intermediate_inputs[:batch_size]
                          if intermediate_inputs is not None else None,
                          "kv_caches":
                          kv_caches[virtual_engine],
                          "attn_metadata":
                          attn_metadata,
                          "memory_pool":
                          self.graph_memory_pool,
                          "stream":
                          graph_capture_context.stream
                      }
                      if previous_hidden_states is not None:
                          capture_inputs[
                              "previous_hidden_states"] = previous_hidden_states[:
                                                                                 batch_size]

                      if self.has_seqlen_agnostic:
                          # Only used by Mamba-based models CUDA graph atm (Jamba)
                          capture_inputs.update({
                              "seqlen_agnostic_capture_inputs":
                              self.model.get_seqlen_agnostic_capture_inputs(
                                  batch_size)
                          })
                      graph_runner.capture(**capture_inputs)
                      self.graph_memory_pool = graph_runner.graph.pool()
                      self.graph_runners[virtual_engine][batch_size] = (
                          graph_runner)

          end_time = time.perf_counter()
          elapsed_time = end_time - start_time
          # This usually takes < 10 seconds.
          if tp_rank == 0:
              logger.info(f"Graph capturing finished in {elapsed_time:.2f} secs")

      @property
      def vocab_size(self) -> int:
          """Vocabulary size reported by the model config."""
          return self.model_config.get_vocab_size()
class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
    """
    GPU model runner with sampling step.

    Specializes the base GPU runner for
    ``ModelInputForGPUWithSamplingMetadata``:

    * ``prepare_model_input`` builds the model-input tensors and, on the last
      pipeline-parallel rank only, attaches ``SamplingMetadata``.
    * ``execute_model`` runs one forward pass (eager or via a captured CUDA
      graph), computes logits on the last pipeline rank and samples the next
      token on the driver worker.
    """
    # Model-input dataclass type associated with this runner (consumed by the
    # base class; its use is not visible in this chunk).
    _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
        ModelInputForGPUWithSamplingMetadata)
    # Builder type used to assemble model inputs (consumed by the base class;
    # presumably by _prepare_model_input_tensors — confirm).
    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Rebuild a model input from a broadcast tensor dict.

        Delegates to
        ``ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict``,
        passing this runner's ``attn_backend`` along with ``tensor_dict``.

        Args:
            tensor_dict: Dict received from the broadcast.

        Returns:
            The reconstructed ``ModelInputForGPUWithSamplingMetadata``.
        """
        model_input = \
            ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
                tensor_dict,
                attn_backend=self.attn_backend,
            )
        return model_input

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        If cuda graph is required, this API automatically pads inputs.

        Args:
            seq_group_metadata_list: Sequence groups scheduled for this step,
                sorted prefill -> decode. May be empty, in which case
                ``is_prompt`` on the result is ``None``.
            virtual_engine: Stored on the returned input; ``execute_model``
                uses it to select the matching CUDA-graph runner.
            finished_requests_ids: Forwarded to
                ``_prepare_model_input_tensors`` and ``get_generators``.

        Returns:
            A copy (via ``dataclasses.replace``) of the prepared model input
            with ``sampling_metadata`` (``None`` on non-last pipeline ranks),
            ``is_prompt`` (taken from the first sequence group) and
            ``virtual_engine`` filled in.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        if get_pp_group().is_last_rank:
            # Sampling metadata is only required for the final pp group
            generators = self.get_generators(finished_requests_ids)
            sampling_metadata = SamplingMetadata.prepare(
                seq_group_metadata_list, model_input.seq_lens,
                model_input.query_lens, self.device, self.pin_memory,
                generators, self.sampling_metadata_cache)
        else:
            sampling_metadata = None
        # Because the list is sorted prefill -> decode, the first group is a
        # prompt whenever the batch contains any prefill.
        is_prompt = (seq_group_metadata_list[0].is_prompt
                     if seq_group_metadata_list else None)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   is_prompt=is_prompt,
                                   virtual_engine=virtual_engine)

    @torch.inference_mode()
    @dump_input_when_exception(exclude_args=[0], exclude_kwargs=["self"])
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
        """Run one forward step and, on the last pipeline stage, sample.

        Args:
            model_input: Input produced by ``prepare_model_input``. Its
                ``attn_metadata`` must be set (asserted below).
            kv_caches: KV-cache tensors handed to the model.
            intermediate_tensors: Activations from the previous pipeline
                stage, if any; passed straight through to the model.
            num_steps: Only ``1`` is supported.

        Returns:
            On a non-last pipeline rank: the model's hidden/intermediate
            states, unchanged. On the last rank: ``[]`` for non-driver
            workers, otherwise a one-element list with the ``SamplerOutput``.

        Raises:
            ValueError: If ``num_steps > 1``.
        """
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in ModelRunner")

        # Activate the LoRA adapters requested for this batch.
        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        # Activate the prompt adapters requested for this batch.
        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        self.attn_state.begin_forward(model_input)

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        # TODO: We can remove this once all
        # virtual engines share the same kv cache.
        virtual_engine = model_input.virtual_engine
        # NOTE(review): when prefill_meta is None this dereferences
        # decode_meta without a None check, i.e. it assumes every batch has
        # at least one prefill or decode part — confirm for empty batches.
        if prefill_meta is None and decode_meta.use_cuda_graph:
            assert model_input.input_tokens is not None
            # Inputs were already padded to a captured size by
            # prepare_model_input, so the token count selects the runner.
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[virtual_engine][
                graph_batch_size]
        else:
            model_executable = self.model

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        # Request bookkeeping is only forwarded to models that declare
        # has_seqlen_agnostic.
        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_seqlen_agnostic else {}
        hidden_or_intermediate_states = model_executable(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            kv_caches=kv_caches,
            attn_metadata=model_input.attn_metadata,
            intermediate_tensors=intermediate_tensors,
            **MultiModalInputs.as_kwargs(multi_modal_kwargs,
                                         device=self.device),
            **seqlen_agnostic_kwargs,
        )

        # Compute the logits in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            return hidden_or_intermediate_states

        # NOTE(review): logits are computed on every worker before the
        # driver-only check below; presumably compute_logits takes part in a
        # cross-rank collective — confirm before reordering.
        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        if not self.is_driver_worker:
            return []

        # Invoke the caller-supplied callback (if any) before sampling.
        if model_input.async_callback is not None:
            model_input.async_callback()

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )

        if self.return_hidden_states:
            # we only need to pass hidden states of most recent token
            assert model_input.sampling_metadata is not None
            indices = model_input.sampling_metadata.selected_token_indices
            if model_input.is_prompt:
                # Keep only the rows selected for sampling; the full prefill
                # states are attached separately.
                hidden_states = hidden_or_intermediate_states.index_select(
                    0, indices)
                output.prefill_hidden_states = hidden_or_intermediate_states
            elif decode_meta.use_cuda_graph:
                # Graph outputs cover the padded batch; drop padding rows.
                hidden_states = hidden_or_intermediate_states[:len(indices)]
            else:
                hidden_states = hidden_or_intermediate_states

            output.hidden_states = hidden_states

        return [output]
class CUDAGraphRunner:
    """Capture one forward pass of ``model`` as a CUDA graph and replay it.

    Lifecycle visible here:

    1. :meth:`capture` is called once with static input tensors. It warms up
       the model, records the graph, and remembers the input/output buffers.
    2. Calling the runner (``__call__`` -> :meth:`forward`) copies fresh
       inputs into those captured buffers, replays the graph, and returns the
       captured output buffer(s).
    """

    def __init__(self, model: nn.Module, backend_name: str,
                 attn_state: AttentionState):
        """
        Args:
            model: Module whose forward pass is captured and replayed.
            backend_name: Attention backend name. ``forward`` skips the
                ``slot_mapping`` copy when this equals ``"No attention"``.
            attn_state: Supplies the attention-specific graph input buffers
                at capture time and refreshes them before each replay.
        """
        self.model = model
        self.backend_name = backend_name
        self.attn_state = attn_state

        # Static tensors the captured graph reads from / writes to.
        # Both stay empty until capture() fills them.
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

        # Set by capture(); None means the graph has not been captured yet.
        self._graph: Optional[torch.cuda.CUDAGraph] = None

    @property
    def graph(self) -> torch.cuda.CUDAGraph:
        """The captured CUDA graph. Asserts that capture() has already run."""
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
                                                      torch.Tensor]],
        intermediate_inputs: Optional[IntermediateTensors],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.cuda.Stream,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Warm up the model, then record its forward pass into a CUDA graph.

        The tensors passed in become the graph's static input buffers. Later
        calls to ``forward`` copy new data into them rather than passing new
        tensors.

        Args:
            input_ids: Static token-id input buffer.
            positions: Static position input buffer.
            hidden_or_intermediate_states: Optional pre-allocated output
                buffer. If given, the captured graph copies the model output
                into it: one tensor on the last pipeline rank, otherwise
                key by key. If ``None``, the model's own output becomes the
                output buffer.
            intermediate_inputs: Activations from the previous pipeline stage
                (``None`` on the first stage). Their tensors are also
                registered as input buffers.
            kv_caches: KV-cache tensors. They are stored but never re-copied
                in ``forward``.
            attn_metadata: Attention metadata used for capture. The
                corresponding buffers come from
                ``attn_state.get_graph_input_buffers``.
            memory_pool: Passed as ``pool`` to ``torch.cuda.graph`` so this
                graph can share memory with earlier captures.
            stream: Stream passed to ``torch.cuda.graph`` for the capture.
            **kwargs: Extra model inputs, also stored as input buffers
                (e.g. ``previous_hidden_states``,
                ``seqlen_agnostic_capture_inputs``).

        Returns:
            The output buffer that ``forward`` will later return.

        Raises:
            AssertionError: If this runner has already captured a graph.
        """
        assert self._graph is None
        # Run the model a few times without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        # Note one iteration is not enough for torch.jit.script
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(
                input_ids=input_ids,
                positions=positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata,
                intermediate_tensors=intermediate_inputs,
                **kwargs,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            output_hidden_or_intermediate_states = self.model(
                input_ids=input_ids,
                positions=positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata,
                intermediate_tensors=intermediate_inputs,
                **kwargs,
            )
            if hidden_or_intermediate_states is not None:
                # Write into the caller-provided buffer as part of the graph.
                if get_pp_group().is_last_rank:
                    hidden_or_intermediate_states.copy_(
                        output_hidden_or_intermediate_states)
                else:
                    for key in hidden_or_intermediate_states.tensors:
                        hidden_or_intermediate_states[key].copy_(
                            output_hidden_or_intermediate_states[key])
            else:
                hidden_or_intermediate_states = (
                    output_hidden_or_intermediate_states)

            del output_hidden_or_intermediate_states
            # make sure `output_hidden_states` is deleted
            # in the graph's memory pool
            gc.collect()
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            **self.attn_state.get_graph_input_buffers(attn_metadata),
            **kwargs,
        }
        if intermediate_inputs is not None:
            self.input_buffers.update(intermediate_inputs.tensors)
        if get_pp_group().is_last_rank:
            self.output_buffers = {
                "hidden_states": hidden_or_intermediate_states
            }
        else:
            self.output_buffers = hidden_or_intermediate_states

        return hidden_or_intermediate_states

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Copy new inputs into the captured buffers and replay the graph.

        The copies use ``Tensor.copy_``, so each input must match, or
        broadcast to, the shape of the buffer captured for this runner.

        Returns:
            On the last pipeline rank, the captured ``"hidden_states"``
            tensor. Otherwise, the captured intermediate-tensor buffers.
            These are the graph's static output buffers, so the same objects
            are returned on every call.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)

        if self.backend_name != "No attention":
            self.input_buffers["slot_mapping"].copy_(
                attn_metadata.slot_mapping, non_blocking=True)

        # Refresh the attention-backend-specific buffers.
        self.attn_state.prepare_graph_input_buffers(self.input_buffers,
                                                    attn_metadata)

        # Only present when the model was captured with seqlen-agnostic
        # inputs; the model itself knows how to refresh them.
        if "seqlen_agnostic_capture_inputs" in self.input_buffers:
            self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
                                                      **kwargs)

        if "previous_hidden_states" in self.input_buffers:
            self.input_buffers["previous_hidden_states"].copy_(
                kwargs["previous_hidden_states"], non_blocking=True)

        if intermediate_tensors is not None:
            for key in intermediate_tensors.tensors:
                self.input_buffers[key].copy_(intermediate_tensors[key],
                                              non_blocking=True)
        # Run the graph.
        self.graph.replay()
        # Return the output tensor.
        if get_pp_group().is_last_rank:
            return self.output_buffers["hidden_states"]

        return self.output_buffers

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`forward`, so the runner is callable like a model."""
        return self.forward(*args, **kwargs)
  1572. def _get_graph_batch_size(batch_size: int) -> int:
  1573. """Returns the padded batch size given actual batch size.
  1574. Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
  1575. 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
  1576. """
  1577. if batch_size <= 2:
  1578. return batch_size
  1579. elif batch_size <= 4:
  1580. return 4
  1581. else:
  1582. return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
  1583. _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
  1584. def _get_max_graph_batch_size(max_num_seqs: int) -> int:
  1585. """
  1586. max_num_seqs: Maximum number of sequences in a batch.
  1587. _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
  1588. pad the max_num_seqs if necessary by calling _get_graph_batch_size,
  1589. which will deal with some edge cases like 1, 2, 4.
  1590. if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size.
  1591. if not, it means the padded size is larger than the largest size in
  1592. _BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE.
  1593. """
  1594. padded_size = _get_graph_batch_size(max_num_seqs)
  1595. if padded_size in _BATCH_SIZES_TO_CAPTURE:
  1596. return padded_size
  1597. assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
  1598. return _BATCH_SIZES_TO_CAPTURE[-1]