model_runner.py 81 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860
  1. import dataclasses
  2. import gc
  3. import inspect
  4. import itertools
  5. import time
  6. import warnings
  7. import weakref
  8. from dataclasses import dataclass
  9. from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
  10. Tuple, Type, TypeVar, Union)
  11. import numpy as np
  12. import torch
  13. import torch.distributed
  14. import torch.nn as nn
  15. from loguru import logger
  16. import aphrodite.common.envs as envs
  17. from aphrodite.attention import AttentionMetadata, get_attn_backend
  18. from aphrodite.attention.backends.abstract import AttentionState
  19. from aphrodite.attention.backends.utils import CommonAttentionState
  20. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  21. LoRAConfig, ModelConfig, ParallelConfig,
  22. PromptAdapterConfig, SchedulerConfig)
  23. from aphrodite.common.sampling_params import SamplingParams
  24. from aphrodite.common.sequence import (IntermediateTensors,
  25. SequenceGroupMetadata)
  26. from aphrodite.common.utils import (CudaMemoryProfiler, PyObjectCache,
  27. async_tensor_h2d, flatten_2d_lists, is_hip,
  28. is_pin_memory_available, supports_dynamo)
  29. from aphrodite.distributed import get_pp_group
  30. from aphrodite.distributed.parallel_state import (
  31. get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
  32. graph_capture)
  33. from aphrodite.inputs import INPUT_REGISTRY, InputRegistry
  34. from aphrodite.lora.layers import LoRAMapping
  35. from aphrodite.lora.request import LoRARequest
  36. from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
  37. from aphrodite.modeling.layers.rotary_embedding import MRotaryEmbedding
  38. from aphrodite.modeling.layers.sampler import SamplerOutput
  39. from aphrodite.modeling.model_loader import get_model
  40. from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
  41. from aphrodite.modeling.models.interfaces import (supports_lora,
  42. supports_multimodal)
  43. from aphrodite.modeling.models.utils import set_cpu_offload_max_bytes
  44. from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
  45. SamplingMetadataCache)
  46. from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
  47. MultiModalInputs, MultiModalRegistry)
  48. from aphrodite.processing.scheduler import SchedulerOutputs
  49. from aphrodite.prompt_adapter.layers import PromptAdapterMapping
  50. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  51. from aphrodite.prompt_adapter.worker_manager import (
  52. LRUCacheWorkerPromptAdapterManager)
  53. from aphrodite.worker.model_runner_base import (
  54. ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
  55. _add_attn_metadata_broadcastable_dict,
  56. _add_sampling_metadata_broadcastable_dict,
  57. _init_attn_metadata_from_tensor_dict,
  58. _init_sampling_metadata_from_tensor_dict, dump_input_when_exception)
  59. if TYPE_CHECKING:
  60. from aphrodite.attention.backends.abstract import AttentionBackend
  61. LORA_WARMUP_RANK = 8
  62. _BATCH_SIZE_ALIGNMENT = 8
  63. # all the token sizes that **can** be captured by cudagraph.
  64. # they can be arbitrarily large.
  65. # currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
  66. # the actual sizes to capture will be determined by the model,
  67. # depending on the model's max_num_seqs.
  68. # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
  69. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
  70. _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
  71. ]
  72. _NUM_WARMUP_ITERS = 2
  73. TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
  74. # For now, bump up cache limits for recompilations during CUDA graph warmups.
  75. torch._dynamo.config.cache_size_limit = 128
  76. torch._dynamo.config.accumulated_cache_size_limit = 128
  77. @dataclass(frozen=True)
  78. class ModelInputForGPU(ModelRunnerInputBase):
  79. """
  80. This base class contains metadata needed for the base model forward pass
  81. but not metadata for possible additional steps, e.g., sampling. Model
  82. runners that run additional steps should subclass this method to add
  83. additional fields.
  84. """
  85. input_tokens: Optional[torch.Tensor] = None
  86. input_positions: Optional[torch.Tensor] = None
  87. seq_lens: Optional[List[int]] = None
  88. query_lens: Optional[List[int]] = None
  89. lora_mapping: Optional["LoRAMapping"] = None
  90. lora_requests: Optional[Set[LoRARequest]] = None
  91. attn_metadata: Optional["AttentionMetadata"] = None
  92. prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
  93. prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
  94. multi_modal_kwargs: Optional[BatchedTensorInputs] = None
  95. request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
  96. finished_requests_ids: Optional[List[str]] = None
  97. virtual_engine: int = 0
  98. async_callback: Optional[Callable] = None
  99. seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
  100. scheduler_outputs: Optional[SchedulerOutputs] = None
  101. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  102. tensor_dict = {
  103. "input_tokens": self.input_tokens,
  104. "input_positions": self.input_positions,
  105. "lora_requests": self.lora_requests,
  106. "lora_mapping": self.lora_mapping,
  107. "multi_modal_kwargs": self.multi_modal_kwargs,
  108. "prompt_adapter_mapping": self.prompt_adapter_mapping,
  109. "prompt_adapter_requests": self.prompt_adapter_requests,
  110. "virtual_engine": self.virtual_engine,
  111. "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
  112. "finished_requests_ids": self.finished_requests_ids,
  113. }
  114. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  115. return tensor_dict
  116. @classmethod
  117. def from_broadcasted_tensor_dict(
  118. cls: Type[TModelInputForGPU],
  119. tensor_dict: Dict[str, Any],
  120. attn_backend: Optional["AttentionBackend"] = None,
  121. ) -> TModelInputForGPU:
  122. if attn_backend is not None:
  123. tensor_dict = _init_attn_metadata_from_tensor_dict(
  124. attn_backend, tensor_dict)
  125. return cls(**tensor_dict)
  126. @dataclass(frozen=True)
  127. class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
  128. """
  129. Used by the ModelRunner.
  130. """
  131. sampling_metadata: Optional["SamplingMetadata"] = None
  132. # Used for speculative decoding. We do not broadcast it because it is only
  133. # used by the driver worker.
  134. is_prompt: Optional[bool] = None
  135. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  136. tensor_dict = {
  137. "input_tokens": self.input_tokens,
  138. "input_positions": self.input_positions,
  139. "lora_requests": self.lora_requests,
  140. "lora_mapping": self.lora_mapping,
  141. "multi_modal_kwargs": self.multi_modal_kwargs,
  142. "prompt_adapter_mapping": self.prompt_adapter_mapping,
  143. "prompt_adapter_requests": self.prompt_adapter_requests,
  144. "virtual_engine": self.virtual_engine,
  145. "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
  146. "finished_requests_ids": self.finished_requests_ids,
  147. }
  148. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  149. _add_sampling_metadata_broadcastable_dict(tensor_dict,
  150. self.sampling_metadata)
  151. return tensor_dict
  152. @classmethod
  153. def from_broadcasted_tensor_dict(
  154. cls,
  155. tensor_dict: Dict[str, Any],
  156. attn_backend: Optional["AttentionBackend"] = None,
  157. ) -> "ModelInputForGPUWithSamplingMetadata":
  158. tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
  159. if attn_backend is not None:
  160. tensor_dict = _init_attn_metadata_from_tensor_dict(
  161. attn_backend, tensor_dict)
  162. return cls(**tensor_dict)
  163. class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
  164. """Build ModelInputForGPU from SequenceGroupMetadata."""
    # NOTE: ideally, we would be using a dataclass(kw_only=True)
    # here, so that this can be subclassed easily, but kw_only
    # is not supported in python<3.10.
    class InterDataForSeqGroup:
        """Intermediate data for the current sequence group.

        Holds the CPU-side per-sequence containers (tokens, positions,
        lengths, sliding-window blocks), the LoRA / prompt-adapter mappings
        and the multi-modal inputs for one sequence group. Objects are meant
        to be recycled: calling ``__init__`` again with ``reinit=True``
        refills the existing containers in place instead of allocating new
        ones (see ``init_cached_inter_data``).
        """

        def simple_reinit(self):
            """Reset a previously initialized single-sequence object in place.

            Only index 0 of the per-sequence containers is touched, so this
            is only correct when the group holds exactly one sequence. Lists
            are cleared or zeroed rather than replaced, so their identity is
            kept for reuse.
            """
            self.input_tokens[0].clear()  # type: ignore
            self.input_positions[0].clear()  # type: ignore
            self.mrope_input_positions = None  # type: ignore
            self.seq_lens[0] = 0  # type: ignore
            self.orig_seq_lens[0] = 0  # type: ignore
            self.query_lens[0] = 0  # type: ignore
            self.context_lens[0] = 0  # type: ignore
            self.curr_sliding_window_blocks[0] = 0  # type: ignore
            self.lora_index_mapping.clear()  # type: ignore
            self.lora_prompt_mapping.clear()  # type: ignore
            self.lora_requests.clear()  # type: ignore
            self.prompt_adapter_index_mapping.clear()  # type: ignore
            self.prompt_adapter_prompt_mapping.clear()  # type: ignore

        def __init__(
            self,
            *,
            # From sequence group metadata.
            request_id: str,
            seq_ids: List[int],
            is_prompt: bool,
            block_tables: Optional[Dict[int, List[int]]],
            computed_block_nums: List[int],
            n_seqs: int = 0,
            # Input tokens and positions.
            input_tokens: Optional[List[List[int]]] = None,
            input_positions: Optional[List[List[int]]] = None,
            mrope_input_positions: Optional[List[List[List[int]]]] = None,
            # The sequence length (may be capped to the sliding window).
            seq_lens: Optional[List[int]] = None,
            # The original sequence length (before applying sliding window).
            # This is used to compute slot mapping.
            orig_seq_lens: Optional[List[int]] = None,
            # The query length.
            query_lens: Optional[List[int]] = None,
            # The number of tokens that are already computed.
            context_lens: Optional[List[int]] = None,
            # The current sliding window block.
            curr_sliding_window_blocks: Optional[List[int]] = None,
            # LoRA inputs.
            lora_index_mapping: Optional[List[List[int]]] = None,
            lora_prompt_mapping: Optional[List[List[int]]] = None,
            lora_requests: Optional[Set[LoRARequest]] = None,
            # Prompt adapter inputs.
            prompt_adapter_index_mapping: Optional[List[int]] = None,
            prompt_adapter_prompt_mapping: Optional[List[int]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
            # Multi-modal inputs.
            multi_modal_inputs: Optional[MultiModalInputs] = None,
            # Whether the prefix cache is hit (prefill only).
            prefix_cache_hit: bool = False,
            reinit: bool = False,
            reinit_use_defaults: bool = False,
            encoder_seq_len: int = 0,
        ):
            """(Re)initialize the intermediate data for a sequence group.

            Fresh object (``reinit=False``): arguments are stored, then
            ``__post_init__`` runs.
            NOTE(review): ``__post_init__`` unconditionally replaces
            ``input_tokens``, ``input_positions``, ``mrope_input_positions``,
            ``seq_lens``, ``orig_seq_lens``, ``query_lens``, ``context_lens``,
            ``curr_sliding_window_blocks``, ``lora_index_mapping`` and
            ``lora_prompt_mapping`` with empty/zeroed values sized by
            ``len(seq_ids)``. Any values passed for those arguments on this
            path are therefore discarded; confirm this is intended.

            Reuse (``reinit=True``): the object must already hold the same
            number of sequences (asserted), and ``seq_ids`` is copied into
            the existing list.
            - With one sequence and ``reinit_use_defaults``,
              ``simple_reinit`` resets the containers and all per-sequence
              arguments are ignored.
            - Otherwise, each non-empty per-sequence argument replaces the
              stored container by reference. Each empty or omitted one causes
              the existing container to be cleared or zeroed in place.
              ``mrope_input_positions`` is always reset to None on this path.

            On every path, ``prompt_adapter_request``, ``multi_modal_inputs``
            and ``prefix_cache_hit`` are taken from the arguments. ``n_seqs``
            is finally recomputed as ``len(self.seq_ids)``, so the ``n_seqs``
            argument has no lasting effect.
            """
            if reinit:
                assert len(self.seq_ids) == len(seq_ids)  # type: ignore
                # Copy in place so the cached list object is reused.
                for i, seq_id in enumerate(seq_ids):
                    self.seq_ids[i] = seq_id  # type: ignore
            else:
                self.seq_ids = seq_ids

            self.request_id = request_id
            self.is_prompt = is_prompt
            self.block_tables = block_tables
            self.computed_block_nums = computed_block_nums
            self.n_seqs = n_seqs
            self.encoder_seq_len = encoder_seq_len

            if reinit:
                if len(self.seq_ids) == 1 and reinit_use_defaults:
                    self.simple_reinit()
                else:
                    # Replace with the given container when non-empty;
                    # otherwise clear/zero the existing one in place.
                    if input_tokens:
                        self.input_tokens = input_tokens
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.input_tokens[seq_id].clear()

                    if input_positions:
                        self.input_positions = input_positions
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.input_positions[seq_id].clear()

                    self.mrope_input_positions = None

                    if seq_lens:
                        self.seq_lens = seq_lens
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.seq_lens[seq_id] = 0

                    if orig_seq_lens:
                        self.orig_seq_lens = orig_seq_lens
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.orig_seq_lens[seq_id] = 0

                    if query_lens:
                        self.query_lens = query_lens
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.query_lens[seq_id] = 0

                    if context_lens:
                        self.context_lens = context_lens
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.context_lens[seq_id] = 0

                    if curr_sliding_window_blocks:
                        self.curr_sliding_window_blocks = \
                            curr_sliding_window_blocks
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.curr_sliding_window_blocks[seq_id] = 0

                    if lora_index_mapping:
                        self.lora_index_mapping = lora_index_mapping
                    else:
                        self.lora_index_mapping.clear()

                    if lora_prompt_mapping:
                        self.lora_prompt_mapping = lora_prompt_mapping
                    else:
                        self.lora_prompt_mapping.clear()

                    if lora_requests:
                        self.lora_requests = lora_requests
                    else:
                        self.lora_requests.clear()

                    if prompt_adapter_index_mapping:
                        self.prompt_adapter_index_mapping = \
                            prompt_adapter_index_mapping
                    else:
                        self.prompt_adapter_index_mapping.clear()

                    if prompt_adapter_prompt_mapping:
                        self.prompt_adapter_prompt_mapping = \
                            prompt_adapter_prompt_mapping
                    else:
                        self.prompt_adapter_prompt_mapping.clear()

            else:
                # NOTE(review): the per-sequence fields set here (all except
                # lora_requests and the prompt_adapter_* mappings) are
                # overwritten by __post_init__ below.
                self.input_tokens = input_tokens or []
                self.input_positions = input_positions or []
                self.mrope_input_positions = mrope_input_positions or None
                self.seq_lens = seq_lens or []
                self.orig_seq_lens = orig_seq_lens or []
                self.query_lens = query_lens or []
                self.context_lens = context_lens or []
                self.curr_sliding_window_blocks = \
                    curr_sliding_window_blocks or []

                self.lora_index_mapping = lora_index_mapping or []
                self.lora_prompt_mapping = lora_prompt_mapping or []
                self.lora_requests = lora_requests or set()

                self.prompt_adapter_index_mapping = (
                    prompt_adapter_index_mapping or [])
                self.prompt_adapter_prompt_mapping = (
                    prompt_adapter_prompt_mapping or [])

            self.prompt_adapter_request = prompt_adapter_request
            self.multi_modal_inputs = multi_modal_inputs
            self.prefix_cache_hit = prefix_cache_hit

            self.n_seqs = len(self.seq_ids)

            if not reinit:
                self.__post_init__()

        def __post_init__(self):
            """Allocate empty/zeroed per-sequence containers for a fresh object.

            Sizes every per-sequence list by ``len(self.seq_ids)``. It also
            resets the LoRA index/prompt mappings to empty lists.
            """
            self.n_seqs = len(self.seq_ids)

            self.input_tokens = [[] for _ in range(self.n_seqs)]
            self.input_positions = [[] for _ in range(self.n_seqs)]
            self.mrope_input_positions = None
            self.seq_lens = [0] * self.n_seqs
            self.orig_seq_lens = [0] * self.n_seqs
            self.query_lens = [0] * self.n_seqs
            self.context_lens = [0] * self.n_seqs
            self.curr_sliding_window_blocks = [0] * self.n_seqs

            self.lora_index_mapping = []
            self.lora_prompt_mapping = []
  335. def gen_inter_data_builder(self, num_seqs: int):
  336. return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup(
  337. request_id="",
  338. seq_ids=[0] * num_seqs,
  339. is_prompt=True,
  340. block_tables=None,
  341. computed_block_nums=[])
  342. def init_cached_inter_data(self, *args, **kwargs):
  343. assert len(args) == 0
  344. assert "seq_ids" in kwargs
  345. seq_ids = kwargs["seq_ids"]
  346. num_seqs = len(seq_ids)
  347. # The inter-data cache is per model_runner
  348. inter_data_cache = self.runner.inter_data_cache
  349. if num_seqs not in inter_data_cache:
  350. inter_data_cache[num_seqs] = PyObjectCache(
  351. self.gen_inter_data_builder(num_seqs))
  352. obj = inter_data_cache[num_seqs].get_object()
  353. obj.__init__(*args, **kwargs)
  354. return obj
  355. def reset_cached_inter_data(self):
  356. for cache in self.runner.inter_data_cache.values():
  357. cache.reset()
    def __init__(self,
                 runner: "GPUModelRunnerBase",
                 finished_requests_ids: Optional[List[str]] = None):
        """Create an input builder bound to ``runner``.

        Args:
            runner: Model runner whose settings this builder copies: model
                input class, attention backend, scheduler config, sliding
                window, block size, LoRA / prompt-adapter configs, and
                multi-modal input mapper.
            finished_requests_ids: Stored unchanged on the builder.
        """
        super().__init__()
        # Compute functions for each sequence in a sequence group.
        # WARNING: The order of the functions matters!
        # (_compute_for_prefix_cache_hit reads the lens written by
        # _compute_lens and may rewrite query_lens; _compute_for_sliding_window
        # then overwrites seq_lens; _compute_lora_input reads query_lens.)
        self.per_seq_compute_fns = [
            self._compute_lens,
            self._compute_for_prefix_cache_hit,
            self._compute_for_sliding_window,
            self._compute_lora_input,
        ]
        # Compute functions for each sequence group.
        # WARNING: The order of the functions matters!
        self.per_seq_group_compute_fns = [
            self._compute_prompt_adapter_input,
            self._compute_multi_modal_input,
        ]

        self.runner = runner
        self.model_input_cls = self.runner._model_input_cls
        self.attn_backend = self.runner.attn_backend
        self.scheduler_config = self.runner.scheduler_config
        self.sliding_window = self.runner.sliding_window
        self.block_size = self.runner.block_size
        self.enable_lora = self.runner.lora_config is not None
        self.enable_prompt_adapter = (self.runner.prompt_adapter_config
                                      is not None)
        self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
        self.finished_requests_ids = finished_requests_ids
        self.decode_only = True

        # Intermediate data (data in CPU before going to GPU) for
        # the current sequence group.
        self.inter_data_list: List[
            ModelInputForGPUBuilder.InterDataForSeqGroup] = []

        # Attention metadata inputs.
        # The builder gets a weak proxy of self; presumably this avoids a
        # strong builder <-> metadata-builder reference cycle (confirm).
        # Attributes it might read from self are all set above this line.
        self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
            weakref.proxy(self))

        # Engine/Model configurations.
        self.chunked_prefill_enabled = (
            self.scheduler_config is not None
            and self.scheduler_config.chunked_prefill_enabled)
        if self.sliding_window is not None:
            # Window size in blocks, rounded up (ceil division), and the
            # window size rounded up to a whole number of blocks. Only
            # defined when a sliding window is configured.
            self.sliding_window_blocks = (
                self.sliding_window + self.block_size - 1) // self.block_size
            self.block_aligned_sliding_window = \
                self.sliding_window_blocks * self.block_size
  404. def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
  405. seq_group_metadata: SequenceGroupMetadata):
  406. """Compute context length, sequence length and tokens
  407. for the given sequence data.
  408. """
  409. seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]]
  410. token_chunk_size = seq_group_metadata.token_chunk_size
  411. # Compute context length (the number of tokens that are
  412. # already computed) and sequence length (total number of tokens).
  413. seq_len = seq_data.get_len()
  414. if inter_data.is_prompt:
  415. context_len = seq_data.get_num_computed_tokens()
  416. else:
  417. # get_num_computed_tokens is incorrect for spec decoding.
  418. # So, we should have a special logic here.
  419. # TODO: Fix it.
  420. context_len = seq_len - 1
  421. seq_len = min(seq_len, context_len + token_chunk_size)
  422. # Compute tokens.
  423. if inter_data.is_prompt:
  424. tokens = seq_data.get_token_ids()
  425. if context_len != 0 or seq_len < len(tokens):
  426. tokens = tokens[context_len:seq_len]
  427. else:
  428. # Optimization. get_token_ids requires the entire copy of
  429. # tokens.
  430. tokens = seq_data.get_last_token_id()
  431. inter_data.seq_lens[seq_idx] = seq_len
  432. inter_data.orig_seq_lens[seq_idx] = seq_len
  433. inter_data.context_lens[seq_idx] = context_len
  434. if isinstance(tokens, list):
  435. inter_data.input_tokens[seq_idx].extend(tokens)
  436. else:
  437. inter_data.input_tokens[seq_idx].append(tokens)
  438. if (seq_len - context_len) == 1:
  439. inter_data.input_positions[seq_idx].append(seq_len - 1)
  440. else:
  441. inter_data.input_positions[seq_idx].extend(
  442. range(context_len, seq_len))
  443. inter_data.query_lens[
  444. seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
  445. if seq_data.mrope_position_delta is not None:
  446. if inter_data.mrope_input_positions is None:
  447. inter_data.mrope_input_positions = [None] * inter_data.n_seqs
  448. inter_data.mrope_input_positions[
  449. seq_idx] = MRotaryEmbedding.get_next_input_positions(
  450. seq_data.mrope_position_delta,
  451. context_len,
  452. seq_len,
  453. )
  454. def _compute_for_prefix_cache_hit(
  455. self, inter_data: InterDataForSeqGroup, seq_idx: int,
  456. seq_group_metadata: SequenceGroupMetadata):
  457. """Check if hit prefix cache (i.e., some blocks are already computed).
  458. If hit, update input tokens and positions to only compute the
  459. remaining blocks.
  460. """
  461. computed_block_nums = inter_data.computed_block_nums
  462. # Note that prefix caching does not support sliding window.
  463. prefix_cache_hit = (computed_block_nums is not None
  464. and len(computed_block_nums) > 0
  465. and self.sliding_window is None
  466. and inter_data.is_prompt)
  467. inter_data.prefix_cache_hit = prefix_cache_hit
  468. if not prefix_cache_hit:
  469. return
  470. assert computed_block_nums is not None
  471. # The cache hit prompt tokens in this sequence. Note that
  472. # this may be larger than the sequence length if chunked
  473. # prefill is enabled.
  474. prefix_cache_len = len(computed_block_nums) * self.block_size
  475. # The number of so far computed prompt tokens in this sequence.
  476. context_len = inter_data.context_lens[seq_idx]
  477. # The total number of prompt tokens in this sequence.
  478. # When chunked prefill is enabled, this is the token number of
  479. # computed chunks + current chunk.
  480. seq_len = inter_data.seq_lens[seq_idx]
  481. if prefix_cache_len <= context_len:
  482. # We already passed the cache hit region,
  483. # so do normal computation.
  484. pass
  485. elif context_len < prefix_cache_len < seq_len:
  486. # Partial hit. Compute the missing part.
  487. uncomputed_start = prefix_cache_len - context_len
  488. inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
  489. seq_idx][uncomputed_start:]
  490. inter_data.input_positions[seq_idx] = inter_data.input_positions[
  491. seq_idx][uncomputed_start:]
  492. context_len = prefix_cache_len
  493. inter_data.context_lens[seq_idx] = context_len
  494. inter_data.query_lens[
  495. seq_idx] = inter_data.seq_lens[seq_idx] - context_len
  496. elif seq_len <= prefix_cache_len:
  497. # Full hit. Only compute the last token to avoid
  498. # erroneous behavior. FIXME: Ideally we should directly
  499. # mark all tokens as computed in the scheduler and do not
  500. # schedule this sequence, so this case should not happen.
  501. inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
  502. seq_idx][-1:]
  503. inter_data.input_positions[seq_idx] = inter_data.input_positions[
  504. seq_idx][-1:]
  505. inter_data.query_lens[seq_idx] = 1
  506. inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
  507. def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
  508. seq_idx: int,
  509. seq_group_metadata: SequenceGroupMetadata):
  510. """Update seq_len and curr_sliding_window_block for the given
  511. sequence data (only required by decoding) if sliding window is enabled.
  512. """
  513. curr_sliding_window_block = 0
  514. sliding_seq_len = inter_data.seq_lens[seq_idx]
  515. if not inter_data.is_prompt and self.sliding_window is not None:
  516. # TODO: This is a hack to make sliding window work with
  517. # paged attn. We can remove it if we make paged attn kernel
  518. # to properly handle slinding window attn.
  519. curr_sliding_window_block = self.sliding_window_blocks
  520. if self.scheduler_config.use_v2_block_manager:
  521. # number of elements in last block
  522. suff_len = inter_data.seq_lens[seq_idx] % self.block_size
  523. sliding_seq_len = min(
  524. inter_data.seq_lens[seq_idx],
  525. self.block_aligned_sliding_window + suff_len)
  526. if suff_len > 0:
  527. curr_sliding_window_block += 1
  528. else:
  529. sliding_seq_len = min(inter_data.seq_lens[seq_idx],
  530. self.sliding_window)
  531. inter_data.curr_sliding_window_blocks[
  532. seq_idx] = curr_sliding_window_block
  533. inter_data.seq_lens[seq_idx] = sliding_seq_len
  534. def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
  535. seq_idx: int,
  536. seq_group_metadata: SequenceGroupMetadata):
  537. """If LoRA is enabled, compute LoRA index and prompt mapping."""
  538. if not self.enable_lora:
  539. return
  540. lora_id = seq_group_metadata.lora_int_id
  541. if lora_id > 0:
  542. inter_data.lora_requests.add(seq_group_metadata.lora_request)
  543. query_len = inter_data.query_lens[seq_idx]
  544. inter_data.lora_index_mapping.append([lora_id] * query_len)
  545. sampling_params = seq_group_metadata.sampling_params
  546. if sampling_params and sampling_params.prompt_logprobs is not None:
  547. inter_data.lora_prompt_mapping.append([lora_id] * query_len)
  548. elif not self.chunked_prefill_enabled or seq_group_metadata.do_sample:
  549. inter_data.lora_prompt_mapping.append([lora_id])
  550. else:
  551. inter_data.lora_prompt_mapping.append([])
  552. def _compute_prompt_adapter_input(
  553. self, inter_data: InterDataForSeqGroup,
  554. seq_group_metadata: SequenceGroupMetadata):
  555. """If prompt adapter is enabled, compute index and prompt mapping.
  556. """
  557. # Note that when is_prompt=True, we expect only one sequence
  558. # in the group.
  559. if not self.enable_prompt_adapter:
  560. return
  561. prompt_adapter_id = seq_group_metadata.prompt_adapter_id
  562. if prompt_adapter_id <= 0 or not inter_data.is_prompt:
  563. return
  564. # We expect only one sequence in the group when is_prompt=True.
  565. assert inter_data.n_seqs == 1
  566. query_len = inter_data.query_lens[0]
  567. inter_data.prompt_adapter_request = (
  568. seq_group_metadata.prompt_adapter_request)
  569. num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens
  570. inter_data.prompt_adapter_index_mapping = [
  571. prompt_adapter_id
  572. ] * num_tokens + [0] * (query_len - num_tokens)
  573. inter_data.prompt_adapter_prompt_mapping = [prompt_adapter_id] * (
  574. query_len if seq_group_metadata.sampling_params
  575. and seq_group_metadata.sampling_params.prompt_logprobs else 1)
  576. def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
  577. seq_group_metadata: SequenceGroupMetadata):
  578. """If multi-modal data is given, add it to the input."""
  579. mm_data = seq_group_metadata.multi_modal_data
  580. if not mm_data:
  581. return
  582. mm_kwargs = self.multi_modal_input_mapper(mm_data)
  583. inter_data.multi_modal_inputs = mm_kwargs
  584. # special processing for mrope position deltas.
  585. if self.runner.model_is_mrope:
  586. image_grid_thw = mm_kwargs.get("image_grid_thw", None)
  587. video_grid_thw = mm_kwargs.get("video_grid_thw", None)
  588. assert image_grid_thw is not None or video_grid_thw is not None, (
  589. "mrope embedding type requires multi-modal input mapper "
  590. "returns 'image_grid_thw' or 'video_grid_thw'.")
  591. hf_config = self.runner.model_config.hf_config
  592. inter_data.mrope_input_positions = [None] * inter_data.n_seqs
  593. for seq_idx in range(inter_data.n_seqs):
  594. seq_data = seq_group_metadata.seq_data[
  595. inter_data.seq_ids[seq_idx]]
  596. token_ids = seq_data.get_token_ids()
  597. mrope_input_positions, mrope_position_delta = \
  598. MRotaryEmbedding.get_input_positions(
  599. token_ids,
  600. image_grid_thw=image_grid_thw,
  601. video_grid_thw=video_grid_thw,
  602. image_token_id=hf_config.image_token_id,
  603. video_token_id=hf_config.video_token_id,
  604. vision_start_token_id=hf_config.vision_start_token_id,
  605. vision_end_token_id=hf_config.vision_end_token_id,
  606. spatial_merge_size=hf_config.vision_config.
  607. spatial_merge_size,
  608. context_len=inter_data.context_lens[seq_idx],
  609. )
  610. seq_data.mrope_position_delta = mrope_position_delta
  611. inter_data.mrope_input_positions[
  612. seq_idx] = mrope_input_positions
  613. def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
  614. """Add a sequence group to the builder."""
  615. seq_ids = seq_group_metadata.seq_data.keys()
  616. n_seqs = len(seq_ids)
  617. is_prompt = seq_group_metadata.is_prompt
  618. if is_prompt:
  619. assert n_seqs == 1
  620. self.decode_only = False
  621. encoder_seq_len = 0
  622. if self.runner.model_config.is_encoder_decoder_model:
  623. encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len()
  624. inter_data = self.init_cached_inter_data(
  625. request_id=seq_group_metadata.request_id,
  626. seq_ids=seq_ids,
  627. is_prompt=is_prompt,
  628. block_tables=seq_group_metadata.block_tables,
  629. computed_block_nums=seq_group_metadata.computed_block_nums,
  630. reinit=True,
  631. reinit_use_defaults=True,
  632. encoder_seq_len=encoder_seq_len)
  633. self.inter_data_list.append(inter_data)
  634. for seq_idx in range(n_seqs):
  635. for per_seq_fn in self.per_seq_compute_fns:
  636. per_seq_fn(inter_data, seq_idx, seq_group_metadata)
  637. for per_seq_group_fn in self.per_seq_group_compute_fns:
  638. per_seq_group_fn(inter_data, seq_group_metadata)
  639. def _use_captured_graph(self,
  640. batch_size: int,
  641. max_decode_seq_len: int,
  642. max_encoder_seq_len: int = 0) -> bool:
  643. return (self.decode_only and not self.runner.model_config.enforce_eager
  644. and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
  645. and max_decode_seq_len <= self.runner.max_seq_len_to_capture
  646. and max_encoder_seq_len <= self.runner.max_seq_len_to_capture
  647. and batch_size <= self.runner.max_batchsize_to_capture)
    def build(self) -> ModelInputForGPU:
        """Finalize the builder intermediate data and
        create on-device tensors.

        Flattens the per-group tokens, positions, sequence lengths and query
        lengths collected in ``self.inter_data_list``. Pads them when the
        batch is eligible for a captured CUDA graph, then builds the
        attention metadata, the LoRA / prompt-adapter mappings and the
        batched multi-modal kwargs.

        Returns:
            A ``self.model_input_cls`` instance. It is empty (no fields set)
            when no group contributed any input token.
        """
        # Combine and flatten intermediate data.
        input_tokens = []
        for inter_data in self.inter_data_list:
            for cur_input_tokens in inter_data.input_tokens:
                input_tokens.extend(cur_input_tokens)

        if not input_tokens:
            # This may happen when all prefill requests hit
            # prefix caching and there is no decode request.
            return self.model_input_cls()

        # M-RoPE positions are kept as 3 parallel lists. A group without
        # M-RoPE positions contributes its plain 1-D positions to each of the
        # 3 lists so all lists stay aligned with `input_tokens`.
        mrope_input_positions: Optional[List[List[int]]] = None
        if any(inter_data.mrope_input_positions is not None
               for inter_data in self.inter_data_list):
            mrope_input_positions = [[] for _ in range(3)]
            for idx in range(3):
                for inter_data in self.inter_data_list:
                    msections = inter_data.mrope_input_positions
                    if msections is None:
                        for _seq_input_positions in inter_data.input_positions:
                            mrope_input_positions[idx].extend(
                                _seq_input_positions)
                    else:
                        for _seq_mrope_input_positions in msections:
                            mrope_input_positions[idx].extend(
                                _seq_mrope_input_positions[idx])
            input_positions = None
        else:
            input_positions = []
            for inter_data in self.inter_data_list:
                for cur_input_positions in inter_data.input_positions:
                    input_positions.extend(cur_input_positions)

        seq_lens = []
        query_lens = []
        max_decode_seq_len = 0
        max_encoder_seq_len = 0
        for inter_data in self.inter_data_list:
            seq_lens.extend(inter_data.seq_lens)
            query_lens.extend(inter_data.query_lens)
            # Only decode groups feed the CUDA-graph eligibility limits.
            if not inter_data.is_prompt:
                max_decode_seq_len = max(max_decode_seq_len,
                                         max(inter_data.seq_lens))
                if self.runner.model_config.is_encoder_decoder_model:
                    max_encoder_seq_len = max(max_encoder_seq_len,
                                              inter_data.encoder_seq_len)

        # Mapping from request IDs to sequence IDs. Used for Jamba models
        # that manages the cache by itself.
        request_ids_to_seq_ids = {
            data.request_id: data.seq_ids
            for data in self.inter_data_list
        }

        # The graph "batch size" is the total token count. `capture_model`
        # assumes one token per sequence for the decode-only batches that
        # can use a graph.
        batch_size = len(input_tokens)
        use_captured_graph = self._use_captured_graph(
            batch_size,
            max_decode_seq_len,
            max_encoder_seq_len=max_encoder_seq_len)

        # If cuda graph can be used, pad tensors accordingly.
        # See `capture_model` API for more details.
        # Aphrodite uses cuda graph only for decoding requests.
        # -1 is the "no CUDA graph" sentinel that is passed on to
        # `attn_metadata_builder.build`. Note that `if cuda_graph_pad_size:`
        # below is also true for -1; padding is still a no-op then because
        # `itertools.repeat(x, -1)` yields nothing.
        cuda_graph_pad_size = -1
        if use_captured_graph:
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            cuda_graph_pad_size = graph_batch_size - batch_size
            batch_size = graph_batch_size

        # Tokens and positions.
        if cuda_graph_pad_size:
            input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
        assert self.runner.device is not None
        input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
                                               self.runner.device,
                                               self.runner.pin_memory)
        if mrope_input_positions is not None:
            for idx in range(3):
                mrope_input_positions[idx].extend(
                    itertools.repeat(0, cuda_graph_pad_size))
            input_positions_tensor = async_tensor_h2d(mrope_input_positions,
                                                      torch.long,
                                                      self.runner.device,
                                                      self.runner.pin_memory)
        else:
            input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
            input_positions_tensor = async_tensor_h2d(input_positions,
                                                      torch.long,
                                                      self.runner.device,
                                                      self.runner.pin_memory)
        # Sequence and query lengths.
        # Padded slots get a sequence length of 1; query_lens is not padded.
        if cuda_graph_pad_size:
            seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))

        # Attention metadata.
        attn_metadata = self.attn_metadata_builder.build(
            seq_lens, query_lens, cuda_graph_pad_size, batch_size)

        # LoRA data.
        lora_requests = set()
        lora_mapping = None
        if self.enable_lora:
            lora_requests = set(r for data in self.inter_data_list
                                for r in data.lora_requests)
            lora_index_mapping = flatten_2d_lists([
                flatten_2d_lists(inter_data.lora_index_mapping)
                for inter_data in self.inter_data_list
            ])
            # Only the per-token index mapping is padded to the graph size;
            # the prompt mapping is left as collected.
            if cuda_graph_pad_size:
                lora_index_mapping.extend(
                    itertools.repeat(0, cuda_graph_pad_size))
            lora_prompt_mapping = flatten_2d_lists([
                flatten_2d_lists(inter_data.lora_prompt_mapping)
                for inter_data in self.inter_data_list
            ])

            lora_mapping = LoRAMapping(
                **dict(index_mapping=lora_index_mapping,
                       prompt_mapping=lora_prompt_mapping,
                       is_prefill=not self.decode_only))

        # Prompt adapter data.
        prompt_adapter_requests: Set[PromptAdapterRequest] = set()
        prompt_adapter_mapping = None
        if self.enable_prompt_adapter:
            prompt_adapter_requests = set(
                data.prompt_adapter_request for data in self.inter_data_list
                if data.prompt_adapter_request is not None)
            prompt_adapter_index_mapping = flatten_2d_lists([
                inter_data.prompt_adapter_index_mapping
                for inter_data in self.inter_data_list
            ])
            if cuda_graph_pad_size:
                prompt_adapter_index_mapping.extend(
                    itertools.repeat(0, cuda_graph_pad_size))
            prompt_adapter_prompt_mapping = flatten_2d_lists([
                inter_data.prompt_adapter_prompt_mapping
                for inter_data in self.inter_data_list
            ])
            prompt_adapter_mapping = PromptAdapterMapping(
                prompt_adapter_index_mapping,
                prompt_adapter_prompt_mapping,
            )

        # Multi-modal data.
        multi_modal_inputs_list = [
            data.multi_modal_inputs for data in self.inter_data_list
            if data.multi_modal_inputs is not None
        ]
        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

        return self.model_input_cls(
            input_tokens=input_tokens_tensor,
            input_positions=input_positions_tensor,
            attn_metadata=attn_metadata,
            seq_lens=seq_lens,
            query_lens=query_lens,
            lora_mapping=lora_mapping,
            lora_requests=lora_requests,
            multi_modal_kwargs=multi_modal_kwargs,
            request_ids_to_seq_ids=request_ids_to_seq_ids,
            finished_requests_ids=self.finished_requests_ids,
            prompt_adapter_mapping=prompt_adapter_mapping,
            prompt_adapter_requests=prompt_adapter_requests)
  804. class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
  805. """
  806. Helper class for shared methods between GPU model runners.
  807. """
  808. _model_input_cls: Type[TModelInputForGPU]
  809. _builder_cls: Type[ModelInputForGPUBuilder]
  810. def __init__(
  811. self,
  812. model_config: ModelConfig,
  813. parallel_config: ParallelConfig,
  814. scheduler_config: SchedulerConfig,
  815. device_config: DeviceConfig,
  816. cache_config: CacheConfig,
  817. load_config: LoadConfig,
  818. lora_config: Optional[LoRAConfig],
  819. kv_cache_dtype: Optional[str] = "auto",
  820. is_driver_worker: bool = False,
  821. prompt_adapter_config: Optional[PromptAdapterConfig] = None,
  822. return_hidden_states: bool = False,
  823. input_registry: InputRegistry = INPUT_REGISTRY,
  824. mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
  825. tp_rank: int = 0,
  826. ):
  827. self.model_config = model_config
  828. self.parallel_config = parallel_config
  829. self.scheduler_config = scheduler_config
  830. self.device_config = device_config
  831. self.cache_config = cache_config
  832. self.lora_config = lora_config
  833. self.load_config = load_config
  834. self.is_driver_worker = is_driver_worker
  835. self.prompt_adapter_config = prompt_adapter_config
  836. self.return_hidden_states = return_hidden_states
  837. self.device = self.device_config.device
  838. self.pin_memory = is_pin_memory_available()
  839. self.tp_rank = tp_rank
  840. self.kv_cache_dtype = kv_cache_dtype
  841. self.sliding_window = model_config.get_sliding_window()
  842. self.block_size = cache_config.block_size
  843. self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
  844. self.max_batchsize_to_capture = _get_max_graph_batch_size(
  845. self.scheduler_config.max_num_seqs)
  846. self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
  847. {} for _ in range(self.parallel_config.pipeline_parallel_size)
  848. ]
  849. self.graph_memory_pool: Optional[Tuple[
  850. int, int]] = None # Set during graph capture.
  851. self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers(
  852. parallel_config)
  853. # When using CUDA graph, the input block tables must be padded to
  854. # max_seq_len_to_capture. However, creating the block table in
  855. # Python can be expensive. To optimize this, we cache the block table
  856. # in numpy and only copy the actual input content at every iteration.
  857. # The shape of the cached block table will be
  858. # (max batch size to capture, max context len to capture / block size).
  859. self.graph_block_tables = np.zeros(
  860. (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
  861. dtype=np.int32)
  862. self.attn_backend = get_attn_backend(
  863. self.model_config.get_head_size(),
  864. self.model_config.get_sliding_window(),
  865. self.model_config.dtype,
  866. self.kv_cache_dtype,
  867. self.block_size,
  868. self.model_config.is_attention_free(),
  869. )
  870. if self.attn_backend:
  871. self.attn_state = self.attn_backend.get_state_cls()(
  872. weakref.proxy(self))
  873. else:
  874. self.attn_state = CommonAttentionState(weakref.proxy(self))
  875. # Multi-modal data support
  876. self.input_registry = input_registry
  877. self.mm_registry = mm_registry
  878. self.multi_modal_input_mapper = mm_registry \
  879. .create_input_mapper(model_config)
  880. self.mm_registry.init_mm_limits_per_prompt(self.model_config)
  881. # Lazy initialization
  882. self.model: nn.Module # Set after load_model
  883. # Set after load_model.
  884. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
  885. self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None
  886. set_cpu_offload_max_bytes(
  887. int(self.cache_config.cpu_offload_gb * 1024**3))
  888. # Used to cache python objects
  889. self.inter_data_cache: Dict[int, PyObjectCache] = {}
  890. self.sampling_metadata_cache: SamplingMetadataCache = \
  891. SamplingMetadataCache()
  892. def load_model(self) -> None:
  893. tp = get_tensor_model_parallel_world_size()
  894. rank = get_tensor_model_parallel_rank()
  895. if rank == 0:
  896. logger.info(f"Loading model {self.model_config.model}...")
  897. with CudaMemoryProfiler() as m:
  898. # measure the time it takes to load the model
  899. start_time = time.time()
  900. self.model = get_model(model_config=self.model_config,
  901. device_config=self.device_config,
  902. load_config=self.load_config,
  903. lora_config=self.lora_config,
  904. parallel_config=self.parallel_config,
  905. scheduler_config=self.scheduler_config,
  906. cache_config=self.cache_config)
  907. end_time = time.time()
  908. self.model_memory_usage = m.consumed_memory
  909. total_time = end_time - start_time
  910. if tp > 1:
  911. if rank == 0:
  912. logger.info(f"Model loaded in {total_time:.2f} seconds.")
  913. logger.info(
  914. "Total model weights memory usage: "
  915. f"{self.model_memory_usage * tp / float(2**30):.2f} GiB")
  916. else:
  917. logger.info(f"Model weights loaded in {total_time:.2f} seconds.")
  918. logger.info(
  919. "Total model weights memory usage: "
  920. f"{self.model_memory_usage / float(2**30):.2f} GiB")
  921. if self.lora_config:
  922. assert supports_lora(self.model), "Model does not support LoRA"
  923. assert not supports_multimodal(
  924. self.model
  925. ), "To be tested: multimodal language model with LoRA settings."
  926. self.lora_manager = LRUCacheWorkerLoRAManager(
  927. self.scheduler_config.max_num_seqs,
  928. self.scheduler_config.max_num_batched_tokens,
  929. self.vocab_size,
  930. self.lora_config,
  931. self.device,
  932. self.model.embedding_modules,
  933. self.model.embedding_padding_modules,
  934. max_position_embeddings=self.model.config.
  935. max_position_embeddings,
  936. )
  937. self.model = self.lora_manager.create_lora_manager(self.model)
  938. if self.prompt_adapter_config:
  939. self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
  940. self.scheduler_config.max_num_seqs,
  941. self.scheduler_config.max_num_batched_tokens, self.device,
  942. self.prompt_adapter_config)
  943. self.model = (
  944. self.prompt_adapter_manager.create_prompt_adapter_manager(
  945. self.model))
  946. if self.kv_cache_dtype == "fp8" and is_hip():
  947. # Currently only ROCm accepts kv-cache scaling factors
  948. # via quantization_param_path and this will be deprecated
  949. # in the future.
  950. if self.model_config.quantization_param_path is not None:
  951. if callable(getattr(self.model, "load_kv_cache_scales", None)):
  952. warnings.warn(
  953. "Loading kv cache scaling factor from JSON is "
  954. "deprecated and will be removed. Please include "
  955. "kv cache scaling factors in the model checkpoint.",
  956. FutureWarning,
  957. stacklevel=2)
  958. self.model.load_kv_cache_scales(
  959. self.model_config.quantization_param_path)
  960. logger.info(
  961. "Loaded KV cache scaling factors from ",
  962. f"{self.model_config.quantization_param_path}")
  963. else:
  964. raise RuntimeError(
  965. "Using FP8 KV cache and scaling factors provided but "
  966. f"model {self.model.__class__} does not support loading"
  967. " scaling factors.", )
  968. else:
  969. logger.warning(
  970. "Using FP8 KV cache but no scaling factors "
  971. "provided. Defaulting to scaling factors of 1.0. "
  972. "This may lead to less accurate results!")
  973. if envs.APHRODITE_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
  974. logger.info("Compiling the model using torch.compile...")
  975. from aphrodite.compilation.backends import aphrodite_backend
  976. from aphrodite.plugins import get_torch_compile_backend
  977. backend = get_torch_compile_backend() or aphrodite_backend
  978. start_time = time.time()
  979. self.model = torch.compile(
  980. self.model,
  981. fullgraph=envs.APHRODITE_TEST_DYNAMO_FULLGRAPH_CAPTURE,
  982. backend=backend)
  983. end_time = time.time()
  984. logger.info(
  985. f"Model compiled in {end_time - start_time:.2f} seconds.")
  986. def get_model_memory_usage(self):
  987. return self.model_memory_usage
  988. def save_sharded_state(
  989. self,
  990. path: str,
  991. pattern: Optional[str] = None,
  992. max_size: Optional[int] = None,
  993. ) -> None:
  994. from aphrodite.modeling.model_loader.loader import ShardedStateLoader
  995. ShardedStateLoader.save_model(
  996. self.model,
  997. path,
  998. pattern=pattern,
  999. max_size=max_size,
  1000. )
  1001. def save_tensorized_model(
  1002. self,
  1003. tensorizer_config: TensorizerConfig,
  1004. ) -> None:
  1005. from aphrodite.modeling.model_loader.loader import TensorizerLoader
  1006. TensorizerLoader.save_model(
  1007. self.model,
  1008. tensorizer_config=tensorizer_config,
  1009. )
  1010. def get_max_block_per_batch(self) -> int:
  1011. block_size = self.block_size
  1012. return (self.max_seq_len_to_capture + block_size - 1) // block_size
  1013. def _prepare_model_input_tensors(
  1014. self,
  1015. seq_group_metadata_list: List[SequenceGroupMetadata],
  1016. finished_requests_ids: Optional[List[str]] = None
  1017. ) -> TModelInputForGPU:
  1018. """Helper method to prepare the model input based on a given sequence
  1019. group. Prepares metadata needed for the base model forward pass but not
  1020. metadata for possible additional steps, e.g., sampling.
  1021. The API assumes seq_group_metadata_list is sorted by prefill -> decode.
  1022. The result tensors and data structure also batches input in prefill
  1023. -> decode order. For example,
  1024. - input_tokens[:num_prefill_tokens] contains prefill tokens.
  1025. - input_tokens[num_prefill_tokens:] contains decode tokens.
  1026. If cuda graph is required, this API automatically pads inputs.
  1027. """
  1028. builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
  1029. for seq_group_metadata in seq_group_metadata_list:
  1030. builder.add_seq_group(seq_group_metadata)
  1031. builder.reset_cached_inter_data()
  1032. return builder.build() # type: ignore
    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run one forward pass on dummy prompts to exercise peak memory.

        The dummy batch is sized from the scheduler limits
        (``max_num_batched_tokens`` / ``max_num_seqs``). It is a single
        sequence of ``max_num_batched_tokens`` tokens in single-user mode.
        When LoRA is enabled, ``max_loras`` dummy adapters are registered and
        attached to the dummy sequences. The model runs with ``None`` KV
        caches, and the CUDA device is synchronized afterwards (presumably so
        the caller can read accurate peak-memory stats; confirm at call site).
        """
        rank = get_tensor_model_parallel_rank()
        if rank == 0:
            logger.info("Profiling peak memory usage...")
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs
        single_seq_mode = self.scheduler_config.single_user_mode
        if single_seq_mode and rank == 0:
            logger.info("Running in single sequence profiling mode")
        # Register one dummy LoRA (rank LORA_WARMUP_RANK) per `max_loras`
        # slot. Every slot is then occupied by a distinct adapter, which
        # represents the maximum memory consumption from unique LoRAs. The
        # dummies are assigned to the profiling sequences round-robin.
        dummy_lora_requests: List[LoRARequest] = []
        dummy_lora_requests_per_seq: List[LoRARequest] = []
        if self.lora_config:
            assert self.lora_manager is not None
            with self.lora_manager.dummy_lora_cache():
                for idx in range(self.lora_config.max_loras):
                    lora_id = idx + 1
                    dummy_lora_request = LoRARequest(
                        lora_name=f"warmup_{lora_id}",
                        lora_int_id=lora_id,
                        lora_local_path="/not/a/real/path",
                    )
                    self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                     rank=LORA_WARMUP_RANK)
                    dummy_lora_requests.append(dummy_lora_request)
                if single_seq_mode:
                    dummy_lora_requests_per_seq = [dummy_lora_requests[0]]
                else:
                    # Sized with the pre-multi-modal max_num_seqs; the value
                    # can only shrink below, so indexing stays in range.
                    dummy_lora_requests_per_seq = [
                        dummy_lora_requests[idx % len(dummy_lora_requests)]
                        for idx in range(max_num_seqs)
                    ]

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for multi-modal encoding, which
        # needs to be accounted for when calculating the GPU blocks for
        # the Aphrodite block manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
            self.model_config)
        if max_mm_tokens > 0 and not single_seq_mode:
            max_num_seqs_orig = max_num_seqs
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // max_mm_tokens)
            if max_num_seqs < 1:
                expr = (f"min({max_num_seqs_orig}, "
                        f"{max_num_batched_tokens} // {max_mm_tokens})")
                logger.warning(
                    f"Computed max_num_seqs ({expr}) to be less than 1. "
                    "Setting it to the minimum value of 1.")
                max_num_seqs = 1

        # Total number of dummy tokens; used to size the intermediate
        # tensors on non-first pipeline ranks.
        batch_size = 0
        if single_seq_mode:
            seq_len = max_num_batched_tokens
            batch_size = seq_len
            seq_data, dummy_multi_modal_data = self.input_registry \
                .dummy_data_for_profiling(self.model_config,
                                          seq_len,
                                          self.mm_registry)
            seq = SequenceGroupMetadata(
                request_id="0",
                is_prompt=True,
                seq_data={0: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[0]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=dummy_multi_modal_data,
            )
            seqs.append(seq)
        else:
            # Original multi-sequence profiling logic
            # Split max_num_batched_tokens as evenly as possible: the first
            # (tokens % seqs) groups get one extra token.
            for group_id in range(max_num_seqs):
                seq_len = (max_num_batched_tokens // max_num_seqs +
                           (group_id < max_num_batched_tokens % max_num_seqs))
                batch_size += seq_len
                seq_data, dummy_multi_modal_data = self.input_registry \
                    .dummy_data_for_profiling(self.model_config,
                                              seq_len,
                                              self.mm_registry)
                seq = SequenceGroupMetadata(
                    request_id=str(group_id),
                    is_prompt=True,
                    seq_data={group_id: seq_data},
                    sampling_params=sampling_params,
                    block_tables=None,
                    lora_request=dummy_lora_requests_per_seq[group_id]
                    if dummy_lora_requests_per_seq else None,
                    multi_modal_data=dummy_multi_modal_data,
                )
                seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        finished_requests_ids = [seq.request_id for seq in seqs]
        model_input = self.prepare_model_input(
            seqs, finished_requests_ids=finished_requests_ids)
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = self.model.make_empty_intermediate_tensors(
                batch_size=batch_size,
                dtype=self.model_config.dtype,
                device=self.device)
        self.execute_model(model_input, kv_caches, intermediate_tensors)
        torch.cuda.synchronize()
        return
  1148. def remove_all_loras(self):
  1149. if not self.lora_manager:
  1150. raise RuntimeError("LoRA is not enabled.")
  1151. self.lora_manager.remove_all_adapters()
  1152. def set_active_loras(self, lora_requests: Set[LoRARequest],
  1153. lora_mapping: LoRAMapping) -> None:
  1154. if not self.lora_manager:
  1155. raise RuntimeError("LoRA is not enabled.")
  1156. self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
  1157. def add_lora(self, lora_request: LoRARequest) -> bool:
  1158. if not self.lora_manager:
  1159. raise RuntimeError("LoRA is not enabled.")
  1160. return self.lora_manager.add_adapter(lora_request)
  1161. def remove_lora(self, lora_id: int) -> bool:
  1162. if not self.lora_manager:
  1163. raise RuntimeError("LoRA is not enabled.")
  1164. return self.lora_manager.remove_adapter(lora_id)
  1165. def pin_lora(self, lora_id: int) -> bool:
  1166. if not self.lora_manager:
  1167. raise RuntimeError("LoRA is not enabled.")
  1168. return self.lora_manager.pin_adapter(lora_id)
  1169. def list_loras(self) -> Set[int]:
  1170. if not self.lora_manager:
  1171. raise RuntimeError("LoRA is not enabled.")
  1172. return self.lora_manager.list_adapters()
  1173. def remove_all_prompt_adapters(self):
  1174. if not self.prompt_adapter_manager:
  1175. raise RuntimeError("PromptAdapter is not enabled.")
  1176. self.prompt_adapter_manager.remove_all_adapters()
  1177. def set_active_prompt_adapters(
  1178. self, prompt_adapter_requests: Set[PromptAdapterRequest],
  1179. prompt_adapter_mapping: PromptAdapterMapping) -> None:
  1180. if not self.prompt_adapter_manager:
  1181. raise RuntimeError("PromptAdapter is not enabled.")
  1182. self.prompt_adapter_manager.set_active_adapters(
  1183. prompt_adapter_requests, prompt_adapter_mapping)
  1184. def add_prompt_adapter(
  1185. self, prompt_adapter_request: PromptAdapterRequest) -> bool:
  1186. if not self.prompt_adapter_manager:
  1187. raise RuntimeError("PromptAdapter is not enabled.")
  1188. return self.prompt_adapter_manager.add_adapter(prompt_adapter_request)
  1189. def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  1190. if not self.prompt_adapter_manager:
  1191. raise RuntimeError("PromptAdapter is not enabled.")
  1192. return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id)
  1193. def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
  1194. if not self.prompt_adapter_manager:
  1195. raise RuntimeError("PromptAdapter is not enabled.")
  1196. return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id)
  1197. def list_prompt_adapters(self) -> Set[int]:
  1198. if not self.prompt_adapter_manager:
  1199. raise RuntimeError("PromptAdapter is not enabled.")
  1200. return self.prompt_adapter_manager.list_adapters()
  1201. @property
  1202. def model_is_mrope(self) -> bool:
  1203. """Detect if the model has "mrope" rope_scaling type.
  1204. mrope requires keep "rope_deltas" between prompt and decoding phases."""
  1205. rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
  1206. if rope_scaling is None:
  1207. return False
  1208. return rope_scaling.get("type", None) == "mrope"
  1209. @torch.inference_mode()
  1210. def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
  1211. """Cuda graph capture a model.
  1212. Note that CUDA graph's performance gain is negligible if number
  1213. of batched tokens are larger than 200. And since CUDA graph
  1214. requires fixed sized tensors, supporting large/variable batch
  1215. size requires high GPU memory overhead. Thus, Aphrodite only captures
  1216. decoding requests. Mixed batch (chunked prefill + decoding) or
  1217. prefill requests are not captured.
  1218. Since it is used for decoding-only, it assumes there's only 1 token
  1219. per sequence in the batch.
  1220. """
  1221. tp_rank = get_tensor_model_parallel_rank()
  1222. assert not self.model_config.enforce_eager
  1223. if tp_rank == 0:
  1224. logger.info(
  1225. "Capturing the model for CUDA graphs. This may lead to "
  1226. "unexpected consequences if the model is not static. To "
  1227. "run the model in eager mode, set 'enforce_eager=True' or "
  1228. "use '--enforce-eager' in the CLI.")
  1229. logger.info(
  1230. "CUDA graphs can take additional 1~3 GiB memory per GPU. "
  1231. "If you are running out of memory, consider decreasing "
  1232. "`gpu_memory_utilization` or enforcing eager mode. "
  1233. "You can also reduce the `max_num_seqs` as needed "
  1234. "to decrease memory usage.")
  1235. start_time = time.perf_counter()
  1236. # Prepare dummy inputs. These will be reused for all batch sizes.
  1237. max_batch_size = self.max_batchsize_to_capture
  1238. input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
  1239. input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
  1240. if self.model_is_mrope:
  1241. input_positions = torch.tile(input_positions, (3, 1))
  1242. # Prepare dummy previous_hidden_states only if needed by the model.
  1243. # This is used by draft models such as EAGLE.
  1244. previous_hidden_states = None
  1245. if "previous_hidden_states" in inspect.signature(
  1246. self.model.forward).parameters:
  1247. previous_hidden_states = torch.empty(
  1248. [max_batch_size,
  1249. self.model_config.get_hidden_size()],
  1250. dtype=self.model_config.dtype,
  1251. device=self.device)
  1252. intermediate_inputs = None
  1253. if not get_pp_group().is_first_rank:
  1254. intermediate_inputs = self.model.make_empty_intermediate_tensors(
  1255. batch_size=max_batch_size,
  1256. dtype=self.model_config.dtype,
  1257. device=self.device)
  1258. # Prepare buffer for outputs. These will be reused for all batch sizes.
  1259. # It will be filled after the first graph capture.
  1260. hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [
  1261. None
  1262. ] * self.parallel_config.pipeline_parallel_size
  1263. graph_batch_size = self.max_batchsize_to_capture
  1264. batch_size_capture_list = [
  1265. bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
  1266. ]
  1267. with self.attn_state.graph_capture(
  1268. max_batch_size), graph_capture() as graph_capture_context:
  1269. # NOTE: Capturing the largest batch size first may help reduce the
  1270. # memory usage of CUDA graph.
  1271. for virtual_engine in range(
  1272. self.parallel_config.pipeline_parallel_size):
  1273. for batch_size in reversed(batch_size_capture_list):
  1274. attn_metadata = (
  1275. self.attn_state.graph_capture_get_metadata_for_batch(
  1276. batch_size,
  1277. is_encoder_decoder_model=self.model_config.
  1278. is_encoder_decoder_model))
  1279. if self.lora_config:
  1280. lora_mapping = LoRAMapping(
  1281. **dict(index_mapping=[0] * batch_size,
  1282. prompt_mapping=[0] * batch_size,
  1283. is_prefill=False))
  1284. self.set_active_loras(set(), lora_mapping)
  1285. if self.prompt_adapter_config:
  1286. prompt_adapter_mapping = PromptAdapterMapping(
  1287. [-1] * batch_size,
  1288. [-1] * batch_size,
  1289. )
  1290. self.set_active_prompt_adapters(
  1291. set(), prompt_adapter_mapping)
  1292. graph_runner = CUDAGraphRunner(
  1293. self.model, self.attn_backend.get_name(),
  1294. self.attn_state.graph_clone(batch_size),
  1295. self.model_config.is_encoder_decoder_model)
  1296. capture_inputs = {
  1297. "input_ids":
  1298. input_tokens[:batch_size],
  1299. "positions":
  1300. input_positions[..., :batch_size],
  1301. "hidden_or_intermediate_states":
  1302. hidden_or_intermediate_states[
  1303. virtual_engine] # type: ignore
  1304. [:batch_size]
  1305. if hidden_or_intermediate_states[virtual_engine]
  1306. is not None else None,
  1307. "intermediate_inputs":
  1308. intermediate_inputs[:batch_size]
  1309. if intermediate_inputs is not None else None,
  1310. "kv_caches":
  1311. kv_caches[virtual_engine],
  1312. "attn_metadata":
  1313. attn_metadata,
  1314. "memory_pool":
  1315. self.graph_memory_pool,
  1316. "stream":
  1317. graph_capture_context.stream
  1318. }
  1319. if previous_hidden_states is not None:
  1320. capture_inputs[
  1321. "previous_hidden_states"] = previous_hidden_states[:
  1322. batch_size]
  1323. if self.has_seqlen_agnostic:
  1324. # Only used by Mamba-based models CUDA graph atm (Jamba)
  1325. capture_inputs.update({
  1326. "seqlen_agnostic_capture_inputs":
  1327. self.model.get_seqlen_agnostic_capture_inputs(
  1328. batch_size)
  1329. })
  1330. if self.model_config.is_encoder_decoder_model:
  1331. # add the additional inputs to capture for
  1332. # encoder-decoder models.
  1333. self._update_inputs_to_capture_for_enc_dec_model(
  1334. capture_inputs)
  1335. graph_runner.capture(**capture_inputs)
  1336. self.graph_memory_pool = graph_runner.graph.pool()
  1337. self.graph_runners[virtual_engine][batch_size] = (
  1338. graph_runner)
  1339. end_time = time.perf_counter()
  1340. elapsed_time = end_time - start_time
  1341. # This usually takes < 10 seconds.
  1342. if tp_rank == 0:
  1343. logger.info(f"Graph capturing finished in {elapsed_time:.2f} secs")
  1344. def _update_inputs_to_capture_for_enc_dec_model(self,
  1345. capture_inputs: Dict[str,
  1346. Any]):
  1347. """
  1348. Updates the set of input tensors needed for CUDA graph capture in an
  1349. encoder-decoder model.
  1350. This method modifies the provided `capture_inputs` dictionary by
  1351. adding tensors specific to encoder-decoder specific models that
  1352. need to be captured for CUDA Graph replay.
  1353. """
  1354. # During the decode phase encoder_input_ids and encoder_positions are
  1355. # unset. Do the same thing for graph capture.
  1356. capture_inputs["encoder_input_ids"] = torch.tensor(
  1357. [], dtype=torch.long).cuda()
  1358. capture_inputs["encoder_positions"] = torch.tensor(
  1359. [], dtype=torch.long).cuda()
  1360. @property
  1361. def vocab_size(self) -> int:
  1362. return self.model_config.get_vocab_size()
  1363. class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
  1364. """
  1365. GPU model runner with sampling step.
  1366. """
  1367. _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
  1368. ModelInputForGPUWithSamplingMetadata)
  1369. _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
  1370. def make_model_input_from_broadcasted_tensor_dict(
  1371. self,
  1372. tensor_dict: Dict[str, Any],
  1373. ) -> ModelInputForGPUWithSamplingMetadata:
  1374. model_input = \
  1375. ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
  1376. tensor_dict,
  1377. attn_backend=self.attn_backend,
  1378. )
  1379. return model_input
  1380. def prepare_model_input(
  1381. self,
  1382. seq_group_metadata_list: List[SequenceGroupMetadata],
  1383. virtual_engine: int = 0,
  1384. finished_requests_ids: Optional[List[str]] = None
  1385. ) -> ModelInputForGPUWithSamplingMetadata:
  1386. """Prepare the model input based on a given sequence group, including
  1387. metadata for the sampling step.
  1388. The API assumes seq_group_metadata_list is sorted by prefill -> decode.
  1389. The result tensors and data structure also batches input in prefill
  1390. -> decode order. For example,
  1391. - input_tokens[:num_prefill_tokens] contains prefill tokens.
  1392. - input_tokens[num_prefill_tokens:] contains decode tokens.
  1393. If cuda graph is required, this API automatically pads inputs.
  1394. """
  1395. model_input = self._prepare_model_input_tensors(
  1396. seq_group_metadata_list, finished_requests_ids)
  1397. if get_pp_group().is_last_rank:
  1398. # Sampling metadata is only required for the final pp group
  1399. generators = self.get_generators(finished_requests_ids)
  1400. sampling_metadata = SamplingMetadata.prepare(
  1401. seq_group_metadata_list, model_input.seq_lens,
  1402. model_input.query_lens, self.device, self.pin_memory,
  1403. generators, self.sampling_metadata_cache)
  1404. else:
  1405. sampling_metadata = None
  1406. is_prompt = (seq_group_metadata_list[0].is_prompt
  1407. if seq_group_metadata_list else None)
  1408. return dataclasses.replace(model_input,
  1409. sampling_metadata=sampling_metadata,
  1410. is_prompt=is_prompt,
  1411. virtual_engine=virtual_engine)
  1412. @torch.inference_mode()
  1413. @dump_input_when_exception(exclude_args=[0], exclude_kwargs=["self"])
  1414. def execute_model(
  1415. self,
  1416. model_input: ModelInputForGPUWithSamplingMetadata,
  1417. kv_caches: List[torch.Tensor],
  1418. intermediate_tensors: Optional[IntermediateTensors] = None,
  1419. num_steps: int = 1,
  1420. ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
  1421. if num_steps > 1:
  1422. raise ValueError("num_steps > 1 is not supported in ModelRunner")
  1423. if self.lora_config:
  1424. assert model_input.lora_requests is not None
  1425. assert model_input.lora_mapping is not None
  1426. self.set_active_loras(model_input.lora_requests,
  1427. model_input.lora_mapping)
  1428. if self.prompt_adapter_config:
  1429. assert model_input.prompt_adapter_requests is not None
  1430. assert model_input.prompt_adapter_mapping is not None
  1431. self.set_active_prompt_adapters(
  1432. model_input.prompt_adapter_requests,
  1433. model_input.prompt_adapter_mapping)
  1434. self.attn_state.begin_forward(model_input)
  1435. # Currently cuda graph is only supported by the decode phase.
  1436. assert model_input.attn_metadata is not None
  1437. prefill_meta = model_input.attn_metadata.prefill_metadata
  1438. decode_meta = model_input.attn_metadata.decode_metadata
  1439. # TODO: We can remove this once all
  1440. # virtual engines share the same kv cache.
  1441. virtual_engine = model_input.virtual_engine
  1442. if prefill_meta is None and decode_meta.use_cuda_graph:
  1443. assert model_input.input_tokens is not None
  1444. graph_batch_size = model_input.input_tokens.shape[0]
  1445. model_executable = self.graph_runners[virtual_engine][
  1446. graph_batch_size]
  1447. else:
  1448. model_executable = self.model
  1449. multi_modal_kwargs = model_input.multi_modal_kwargs or {}
  1450. seqlen_agnostic_kwargs = {
  1451. "finished_requests_ids": model_input.finished_requests_ids,
  1452. "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
  1453. } if self.has_seqlen_agnostic else {}
  1454. hidden_or_intermediate_states = model_executable(
  1455. input_ids=model_input.input_tokens,
  1456. positions=model_input.input_positions,
  1457. kv_caches=kv_caches,
  1458. attn_metadata=model_input.attn_metadata,
  1459. intermediate_tensors=intermediate_tensors,
  1460. **MultiModalInputs.as_kwargs(multi_modal_kwargs,
  1461. device=self.device),
  1462. **seqlen_agnostic_kwargs,
  1463. )
  1464. # Compute the logits in the last pipeline stage.
  1465. if not get_pp_group().is_last_rank:
  1466. return hidden_or_intermediate_states
  1467. logits = self.model.compute_logits(hidden_or_intermediate_states,
  1468. model_input.sampling_metadata)
  1469. if not self.is_driver_worker:
  1470. return []
  1471. if model_input.async_callback is not None:
  1472. model_input.async_callback()
  1473. # Sample the next token.
  1474. output: SamplerOutput = self.model.sample(
  1475. logits=logits,
  1476. sampling_metadata=model_input.sampling_metadata,
  1477. )
  1478. if self.return_hidden_states:
  1479. # we only need to pass hidden states of most recent token
  1480. assert model_input.sampling_metadata is not None
  1481. indices = model_input.sampling_metadata.selected_token_indices
  1482. if model_input.is_prompt:
  1483. hidden_states = hidden_or_intermediate_states.index_select(
  1484. 0, indices)
  1485. output.prefill_hidden_states = hidden_or_intermediate_states
  1486. elif decode_meta.use_cuda_graph:
  1487. hidden_states = hidden_or_intermediate_states[:len(indices)]
  1488. else:
  1489. hidden_states = hidden_or_intermediate_states
  1490. output.hidden_states = hidden_states
  1491. return [output]
  1492. class CUDAGraphRunner:
  1493. def __init__(self, model: nn.Module, backend_name: str,
  1494. attn_state: AttentionState, is_encoder_decoder_model: bool):
  1495. self.model = model
  1496. self.backend_name = backend_name
  1497. self.attn_state = attn_state
  1498. self.input_buffers: Dict[str, torch.Tensor] = {}
  1499. self.output_buffers: Dict[str, torch.Tensor] = {}
  1500. self._graph: Optional[torch.cuda.CUDAGraph] = None
  1501. self._is_encoder_decoder_model = is_encoder_decoder_model
  1502. @property
  1503. def graph(self):
  1504. assert self._graph is not None
  1505. return self._graph
  1506. def capture(
  1507. self,
  1508. input_ids: torch.Tensor,
  1509. positions: torch.Tensor,
  1510. hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
  1511. torch.Tensor]],
  1512. intermediate_inputs: Optional[IntermediateTensors],
  1513. kv_caches: List[torch.Tensor],
  1514. attn_metadata: AttentionMetadata,
  1515. memory_pool: Optional[Tuple[int, int]],
  1516. stream: torch.cuda.Stream,
  1517. **kwargs,
  1518. ) -> Union[torch.Tensor, IntermediateTensors]:
  1519. assert self._graph is None
  1520. # Run the model a few times without capturing the graph.
  1521. # This is to make sure that the captured graph does not include the
  1522. # kernel launches for initial benchmarking (e.g., Triton autotune).
  1523. # Note one iteration is not enough for torch.jit.script
  1524. for _ in range(_NUM_WARMUP_ITERS):
  1525. self.model(
  1526. input_ids=input_ids,
  1527. positions=positions,
  1528. kv_caches=kv_caches,
  1529. attn_metadata=attn_metadata,
  1530. intermediate_tensors=intermediate_inputs,
  1531. **kwargs,
  1532. )
  1533. # Wait for the warm up operations to finish before proceeding with
  1534. # Graph Capture.
  1535. torch.cuda.synchronize()
  1536. # Capture the graph.
  1537. self._graph = torch.cuda.CUDAGraph()
  1538. with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
  1539. output_hidden_or_intermediate_states = self.model(
  1540. input_ids=input_ids,
  1541. positions=positions,
  1542. kv_caches=kv_caches,
  1543. attn_metadata=attn_metadata,
  1544. intermediate_tensors=intermediate_inputs,
  1545. **kwargs,
  1546. )
  1547. if hidden_or_intermediate_states is not None:
  1548. if get_pp_group().is_last_rank:
  1549. hidden_or_intermediate_states.copy_(
  1550. output_hidden_or_intermediate_states)
  1551. else:
  1552. for key in hidden_or_intermediate_states.tensors:
  1553. hidden_or_intermediate_states[key].copy_(
  1554. output_hidden_or_intermediate_states[key])
  1555. else:
  1556. hidden_or_intermediate_states = (
  1557. output_hidden_or_intermediate_states)
  1558. del output_hidden_or_intermediate_states
  1559. # make sure `output_hidden_states` is deleted
  1560. # in the graph's memory pool
  1561. gc.collect()
  1562. torch.cuda.synchronize()
  1563. # Save the input and output buffers.
  1564. self.input_buffers = {
  1565. "input_ids":
  1566. input_ids,
  1567. "positions":
  1568. positions,
  1569. "kv_caches":
  1570. kv_caches,
  1571. **self.attn_state.get_graph_input_buffers(
  1572. attn_metadata, self._is_encoder_decoder_model),
  1573. **kwargs,
  1574. }
  1575. if intermediate_inputs is not None:
  1576. self.input_buffers.update(intermediate_inputs.tensors)
  1577. if get_pp_group().is_last_rank:
  1578. self.output_buffers = {
  1579. "hidden_states": hidden_or_intermediate_states
  1580. }
  1581. else:
  1582. self.output_buffers = hidden_or_intermediate_states
  1583. return hidden_or_intermediate_states
  1584. def forward(
  1585. self,
  1586. input_ids: torch.Tensor,
  1587. positions: torch.Tensor,
  1588. kv_caches: List[torch.Tensor],
  1589. attn_metadata: AttentionMetadata,
  1590. intermediate_tensors: Optional[IntermediateTensors],
  1591. **kwargs,
  1592. ) -> torch.Tensor:
  1593. # KV caches are fixed tensors, so we don't need to copy them.
  1594. del kv_caches
  1595. # Copy the input tensors to the input buffers.
  1596. self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
  1597. self.input_buffers["positions"].copy_(positions, non_blocking=True)
  1598. if self.backend_name != "No attention":
  1599. self.input_buffers["slot_mapping"].copy_(
  1600. attn_metadata.slot_mapping, non_blocking=True)
  1601. self.attn_state.prepare_graph_input_buffers(
  1602. self.input_buffers, attn_metadata, self._is_encoder_decoder_model)
  1603. if "seqlen_agnostic_capture_inputs" in self.input_buffers:
  1604. self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
  1605. **kwargs)
  1606. if "previous_hidden_states" in self.input_buffers:
  1607. self.input_buffers["previous_hidden_states"].copy_(
  1608. kwargs["previous_hidden_states"], non_blocking=True)
  1609. if intermediate_tensors is not None:
  1610. for key in intermediate_tensors.tensors:
  1611. self.input_buffers[key].copy_(intermediate_tensors[key],
  1612. non_blocking=True)
  1613. if self._is_encoder_decoder_model:
  1614. self.input_buffers["encoder_input_ids"].copy_(
  1615. kwargs['encoder_input_ids'], non_blocking=True)
  1616. self.input_buffers["encoder_positions"].copy_(
  1617. kwargs['encoder_positions'], non_blocking=True)
  1618. # Run the graph.
  1619. self.graph.replay()
  1620. # Return the output tensor.
  1621. if get_pp_group().is_last_rank:
  1622. return self.output_buffers["hidden_states"]
  1623. return self.output_buffers
  1624. def __call__(self, *args, **kwargs):
  1625. return self.forward(*args, **kwargs)
  1626. def _get_graph_batch_size(batch_size: int) -> int:
  1627. """Returns the padded batch size given actual batch size.
  1628. Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
  1629. 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
  1630. """
  1631. if batch_size <= 2:
  1632. return batch_size
  1633. elif batch_size <= 4:
  1634. return 4
  1635. else:
  1636. return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
  1637. _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
  1638. def _get_max_graph_batch_size(max_num_seqs: int) -> int:
  1639. """
  1640. max_num_seqs: Maximum number of sequences in a batch.
  1641. _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
  1642. pad the max_num_seqs if necessary by calling _get_graph_batch_size,
  1643. which will deal with some edge cases like 1, 2, 4.
  1644. if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size.
  1645. if not, it means the padded size is larger than the largest size in
  1646. _BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE.
  1647. """
  1648. padded_size = _get_graph_batch_size(max_num_seqs)
  1649. if padded_size in _BATCH_SIZES_TO_CAPTURE:
  1650. return padded_size
  1651. assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
  1652. return _BATCH_SIZES_TO_CAPTURE[-1]