model_runner.py 73 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711
  1. import dataclasses
  2. import gc
  3. import inspect
  4. import itertools
  5. import time
  6. import warnings
  7. import weakref
  8. from dataclasses import dataclass
  9. from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
  10. Tuple, Type, TypeVar, Union)
  11. import numpy as np
  12. import torch
  13. import torch.distributed
  14. import torch.nn as nn
  15. from loguru import logger
  16. import aphrodite.common.envs as envs
  17. from aphrodite.attention import AttentionMetadata, get_attn_backend
  18. from aphrodite.attention.backends.abstract import AttentionState
  19. from aphrodite.attention.backends.utils import CommonAttentionState
  20. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  21. LoRAConfig, ModelConfig, ParallelConfig,
  22. PromptAdapterConfig, SchedulerConfig)
  23. from aphrodite.common.sampling_params import SamplingParams
  24. from aphrodite.common.sequence import (IntermediateTensors,
  25. SequenceGroupMetadata)
  26. from aphrodite.common.utils import (CudaMemoryProfiler, PyObjectCache,
  27. async_tensor_h2d, flatten_2d_lists, is_hip,
  28. is_pin_memory_available, supports_dynamo)
  29. from aphrodite.distributed import get_pp_group
  30. from aphrodite.distributed.parallel_state import (
  31. get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
  32. graph_capture)
  33. from aphrodite.inputs import INPUT_REGISTRY, InputRegistry
  34. from aphrodite.lora.layers import LoRAMapping
  35. from aphrodite.lora.request import LoRARequest
  36. from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
  37. from aphrodite.modeling.layers.sampler import SamplerOutput
  38. from aphrodite.modeling.model_loader import get_model
  39. from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
  40. from aphrodite.modeling.models.interfaces import (supports_lora,
  41. supports_multimodal)
  42. from aphrodite.modeling.models.utils import set_cpu_offload_max_bytes
  43. from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
  44. SamplingMetadataCache)
  45. from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
  46. MultiModalInputs, MultiModalRegistry)
  47. from aphrodite.prompt_adapter.layers import PromptAdapterMapping
  48. from aphrodite.prompt_adapter.request import PromptAdapterRequest
  49. from aphrodite.prompt_adapter.worker_manager import (
  50. LRUCacheWorkerPromptAdapterManager)
  51. from aphrodite.task_handler.model_runner_base import (
  52. ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
  53. _add_attn_metadata_broadcastable_dict,
  54. _add_sampling_metadata_broadcastable_dict,
  55. _init_attn_metadata_from_tensor_dict,
  56. _init_sampling_metadata_from_tensor_dict)
  57. if TYPE_CHECKING:
  58. from aphrodite.attention.backends.abstract import AttentionBackend
  59. LORA_WARMUP_RANK = 8
  60. _BATCH_SIZE_ALIGNMENT = 8
  61. # all the token sizes that **can** be captured by cudagraph.
  62. # they can be arbitrarily large.
  63. # currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
  64. # the actual sizes to capture will be determined by the model,
  65. # depending on the model's max_num_seqs.
  66. # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
  67. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
  68. _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
  69. ]
  70. _NUM_WARMUP_ITERS = 2
  71. TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
  72. @dataclass(frozen=True)
  73. class ModelInputForGPU(ModelRunnerInputBase):
  74. """
  75. This base class contains metadata needed for the base model forward pass
  76. but not metadata for possible additional steps, e.g., sampling. Model
  77. runners that run additional steps should subclass this method to add
  78. additional fields.
  79. """
  80. input_tokens: Optional[torch.Tensor] = None
  81. input_positions: Optional[torch.Tensor] = None
  82. seq_lens: Optional[List[int]] = None
  83. query_lens: Optional[List[int]] = None
  84. lora_mapping: Optional["LoRAMapping"] = None
  85. lora_requests: Optional[Set[LoRARequest]] = None
  86. attn_metadata: Optional["AttentionMetadata"] = None
  87. prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
  88. prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
  89. multi_modal_kwargs: Optional[BatchedTensorInputs] = None
  90. request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
  91. finished_requests_ids: Optional[List[str]] = None
  92. virtual_engine: int = 0
  93. async_callback: Optional[Callable] = None
  94. use_async_and_multi_step: bool = False
  95. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  96. tensor_dict = {
  97. "input_tokens": self.input_tokens,
  98. "input_positions": self.input_positions,
  99. "lora_requests": self.lora_requests,
  100. "lora_mapping": self.lora_mapping,
  101. "multi_modal_kwargs": self.multi_modal_kwargs,
  102. "prompt_adapter_mapping": self.prompt_adapter_mapping,
  103. "prompt_adapter_requests": self.prompt_adapter_requests,
  104. "virtual_engine": self.virtual_engine,
  105. "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
  106. "finished_requests_ids": self.finished_requests_ids,
  107. }
  108. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  109. return tensor_dict
  110. @classmethod
  111. def from_broadcasted_tensor_dict(
  112. cls: Type[TModelInputForGPU],
  113. tensor_dict: Dict[str, Any],
  114. attn_backend: Optional["AttentionBackend"] = None,
  115. ) -> TModelInputForGPU:
  116. if attn_backend is not None:
  117. tensor_dict = _init_attn_metadata_from_tensor_dict(
  118. attn_backend, tensor_dict)
  119. return cls(**tensor_dict)
  120. @dataclass(frozen=True)
  121. class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
  122. """
  123. Used by the ModelRunner.
  124. """
  125. sampling_metadata: Optional["SamplingMetadata"] = None
  126. # Used for speculative decoding. We do not broadcast it because it is only
  127. # used by the driver worker.
  128. is_prompt: Optional[bool] = None
  129. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  130. tensor_dict = {
  131. "input_tokens": self.input_tokens,
  132. "input_positions": self.input_positions,
  133. "lora_requests": self.lora_requests,
  134. "lora_mapping": self.lora_mapping,
  135. "multi_modal_kwargs": self.multi_modal_kwargs,
  136. "prompt_adapter_mapping": self.prompt_adapter_mapping,
  137. "prompt_adapter_requests": self.prompt_adapter_requests,
  138. "virtual_engine": self.virtual_engine,
  139. "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
  140. "finished_requests_ids": self.finished_requests_ids,
  141. }
  142. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  143. _add_sampling_metadata_broadcastable_dict(tensor_dict,
  144. self.sampling_metadata)
  145. return tensor_dict
  146. @classmethod
  147. def from_broadcasted_tensor_dict(
  148. cls,
  149. tensor_dict: Dict[str, Any],
  150. attn_backend: Optional["AttentionBackend"] = None,
  151. ) -> "ModelInputForGPUWithSamplingMetadata":
  152. tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
  153. if attn_backend is not None:
  154. tensor_dict = _init_attn_metadata_from_tensor_dict(
  155. attn_backend, tensor_dict)
  156. return cls(**tensor_dict)
  157. class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
  158. """Build ModelInputForGPU from SequenceGroupMetadata."""
  159. # NOTE: ideally, we dould be using a dataclass(kw_only=True)
  160. # here, so that this can be subclassed easily, but kw_only
  161. # is not supported in python<3.10.
  162. class InterDataForSeqGroup:
  163. """Intermediate data for the current sequence group."""
  164. def simple_reinit(self):
  165. self.input_tokens[0].clear() # type: ignore
  166. self.input_positions[0].clear() # type: ignore
  167. self.seq_lens[0] = 0 # type: ignore
  168. self.orig_seq_lens[0] = 0 # type: ignore
  169. self.query_lens[0] = 0 # type: ignore
  170. self.context_lens[0] = 0 # type: ignore
  171. self.curr_sliding_window_blocks[0] = 0 # type: ignore
  172. self.lora_index_mapping.clear() # type: ignore
  173. self.lora_prompt_mapping.clear() # type: ignore
  174. self.lora_requests.clear() # type: ignore
  175. self.prompt_adapter_index_mapping.clear() # type: ignore
  176. self.prompt_adapter_prompt_mapping.clear() # type: ignore
  177. def __init__(
  178. self,
  179. *,
  180. # From sequence group metadata.
  181. request_id: str,
  182. seq_ids: List[int],
  183. is_prompt: bool,
  184. block_tables: Optional[Dict[int, List[int]]],
  185. computed_block_nums: List[int],
  186. n_seqs: int = 0,
  187. # Input tokens and positions.
  188. input_tokens: Optional[List[List[int]]] = None,
  189. input_positions: Optional[List[List[int]]] = None,
  190. # The sequence length (may be capped to the sliding window).
  191. seq_lens: Optional[List[int]] = None,
  192. # The original sequence length (before applying sliding window).
  193. # This is used to compute slot mapping.
  194. orig_seq_lens: Optional[List[int]] = None,
  195. # The query length.
  196. query_lens: Optional[List[int]] = None,
  197. # The number of tokens that are already computed.
  198. context_lens: Optional[List[int]] = None,
  199. # The current sliding window block.
  200. curr_sliding_window_blocks: Optional[List[int]] = None,
  201. # LoRA inputs.
  202. lora_index_mapping: Optional[List[List[int]]] = None,
  203. lora_prompt_mapping: Optional[List[List[int]]] = None,
  204. lora_requests: Optional[Set[LoRARequest]] = None,
  205. # Prompt adapter inputs.
  206. prompt_adapter_index_mapping: Optional[List[int]] = None,
  207. prompt_adapter_prompt_mapping: Optional[List[int]] = None,
  208. prompt_adapter_request: Optional[PromptAdapterRequest] = None,
  209. # Multi-modal inputs.
  210. multi_modal_inputs: Optional[MultiModalInputs] = None,
  211. # Whether the prefix cache is hit (prefill only).
  212. prefix_cache_hit: bool = False,
  213. reinit: bool = False,
  214. reinit_use_defaults: bool = False,
  215. ):
  216. if reinit:
  217. assert len(self.seq_ids) == len(seq_ids) # type: ignore
  218. for i, seq_id in enumerate(seq_ids):
  219. self.seq_ids[i] = seq_id # type: ignore
  220. else:
  221. self.seq_ids = seq_ids
  222. self.request_id = request_id
  223. self.is_prompt = is_prompt
  224. self.block_tables = block_tables
  225. self.computed_block_nums = computed_block_nums
  226. self.n_seqs = n_seqs
  227. if reinit:
  228. if len(self.seq_ids) == 1 and reinit_use_defaults:
  229. self.simple_reinit()
  230. else:
  231. if input_tokens:
  232. self.input_tokens = input_tokens
  233. else:
  234. for seq_id in range(len(self.seq_ids)):
  235. self.input_tokens[seq_id].clear()
  236. if input_positions:
  237. self.input_positions = input_positions
  238. else:
  239. for seq_id in range(len(self.seq_ids)):
  240. self.input_positions[seq_id].clear()
  241. if seq_lens:
  242. self.seq_lens = seq_lens
  243. else:
  244. for seq_id in range(len(self.seq_ids)):
  245. self.seq_lens[seq_id] = 0
  246. if orig_seq_lens:
  247. self.orig_seq_lens = orig_seq_lens
  248. else:
  249. for seq_id in range(len(self.seq_ids)):
  250. self.orig_seq_lens[seq_id] = 0
  251. if query_lens:
  252. self.query_lens = query_lens
  253. else:
  254. for seq_id in range(len(self.seq_ids)):
  255. self.query_lens[seq_id] = 0
  256. if context_lens:
  257. self.context_lens = context_lens
  258. else:
  259. for seq_id in range(len(self.seq_ids)):
  260. self.context_lens[seq_id] = 0
  261. if curr_sliding_window_blocks:
  262. self.curr_sliding_window_blocks = \
  263. curr_sliding_window_blocks
  264. else:
  265. for seq_id in range(len(self.seq_ids)):
  266. self.curr_sliding_window_blocks[seq_id] = 0
  267. if lora_index_mapping:
  268. self.lora_index_mapping = lora_index_mapping
  269. else:
  270. self.lora_index_mapping.clear()
  271. if lora_prompt_mapping:
  272. self.lora_prompt_mapping = lora_prompt_mapping
  273. else:
  274. self.lora_prompt_mapping.clear()
  275. if lora_requests:
  276. self.lora_requests = lora_requests
  277. else:
  278. self.lora_requests.clear()
  279. if prompt_adapter_index_mapping:
  280. self.prompt_adapter_index_mapping = \
  281. prompt_adapter_index_mapping
  282. else:
  283. self.prompt_adapter_index_mapping.clear()
  284. if prompt_adapter_prompt_mapping:
  285. self.prompt_adapter_prompt_mapping = \
  286. prompt_adapter_prompt_mapping
  287. else:
  288. self.prompt_adapter_prompt_mapping.clear()
  289. else:
  290. self.input_tokens = input_tokens or []
  291. self.input_positions = input_positions or []
  292. self.seq_lens = seq_lens or []
  293. self.orig_seq_lens = orig_seq_lens or []
  294. self.query_lens = query_lens or []
  295. self.context_lens = context_lens or []
  296. self.curr_sliding_window_blocks = \
  297. curr_sliding_window_blocks or []
  298. self.lora_index_mapping = lora_index_mapping or []
  299. self.lora_prompt_mapping = lora_prompt_mapping or []
  300. self.lora_requests = lora_requests or set()
  301. self.prompt_adapter_index_mapping = (
  302. prompt_adapter_index_mapping or [])
  303. self.prompt_adapter_prompt_mapping = (
  304. prompt_adapter_prompt_mapping or [])
  305. self.prompt_adapter_request = prompt_adapter_request
  306. self.multi_modal_inputs = multi_modal_inputs
  307. self.prefix_cache_hit = prefix_cache_hit
  308. self.n_seqs = len(self.seq_ids)
  309. if not reinit:
  310. self.__post_init__()
  311. def __post_init__(self):
  312. self.n_seqs = len(self.seq_ids)
  313. self.input_tokens = [[] for _ in range(self.n_seqs)]
  314. self.input_positions = [[] for _ in range(self.n_seqs)]
  315. self.seq_lens = [0] * self.n_seqs
  316. self.orig_seq_lens = [0] * self.n_seqs
  317. self.query_lens = [0] * self.n_seqs
  318. self.context_lens = [0] * self.n_seqs
  319. self.curr_sliding_window_blocks = [0] * self.n_seqs
  320. self.lora_index_mapping = []
  321. self.lora_prompt_mapping = []
  322. def gen_inter_data_builder(self, num_seqs: int):
  323. return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup(
  324. request_id="",
  325. seq_ids=[0] * num_seqs,
  326. is_prompt=True,
  327. block_tables=None,
  328. computed_block_nums=[])
  329. def init_cached_inter_data(self, *args, **kwargs):
  330. assert len(args) == 0
  331. assert "seq_ids" in kwargs
  332. seq_ids = kwargs["seq_ids"]
  333. num_seqs = len(seq_ids)
  334. # The inter-data cache is per model_runner
  335. inter_data_cache = self.runner.inter_data_cache
  336. if num_seqs not in inter_data_cache:
  337. inter_data_cache[num_seqs] = PyObjectCache(
  338. self.gen_inter_data_builder(num_seqs))
  339. obj = inter_data_cache[num_seqs].get_object()
  340. obj.__init__(*args, **kwargs)
  341. return obj
  342. def reset_cached_inter_data(self):
  343. for cache in self.runner.inter_data_cache.values():
  344. cache.reset()
  345. def __init__(self,
  346. runner: "GPUModelRunnerBase",
  347. finished_requests_ids: Optional[List[str]] = None):
  348. super().__init__()
  349. # Compute functions for each sequence in a sequence group.
  350. # WARNING: The order of the functions matters!
  351. self.per_seq_compute_fns = [
  352. self._compute_lens,
  353. self._compute_for_prefix_cache_hit,
  354. self._compute_for_sliding_window,
  355. self._compute_lora_input,
  356. ]
  357. # Compute functions for each sequence group.
  358. # WARNING: The order of the functions matters!
  359. self.per_seq_group_compute_fns = [
  360. self._compute_prompt_adapter_input,
  361. self._compute_multi_modal_input,
  362. ]
  363. self.runner = runner
  364. self.model_input_cls = self.runner._model_input_cls
  365. self.attn_backend = self.runner.attn_backend
  366. self.scheduler_config = self.runner.scheduler_config
  367. self.sliding_window = self.runner.sliding_window
  368. self.block_size = self.runner.block_size
  369. self.enable_lora = self.runner.lora_config is not None
  370. self.enable_prompt_adapter = (self.runner.prompt_adapter_config
  371. is not None)
  372. self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
  373. self.finished_requests_ids = finished_requests_ids
  374. self.decode_only = True
  375. # Intermediate data (data in CPU before going to GPU) for
  376. # the current sequence group.
  377. self.inter_data_list: List[
  378. ModelInputForGPUBuilder.InterDataForSeqGroup] = []
  379. # Attention metadata inputs.
  380. self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
  381. weakref.proxy(self))
  382. # Engine/Model configurations.
  383. self.chunked_prefill_enabled = (
  384. self.scheduler_config is not None
  385. and self.scheduler_config.chunked_prefill_enabled)
  386. if self.sliding_window is not None:
  387. self.sliding_window_blocks = (
  388. self.sliding_window + self.block_size - 1) // self.block_size
  389. self.block_aligned_sliding_window = \
  390. self.sliding_window_blocks * self.block_size
  391. def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
  392. seq_group_metadata: SequenceGroupMetadata):
  393. """Compute context length, sequence length and tokens
  394. for the given sequence data.
  395. """
  396. seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]]
  397. token_chunk_size = seq_group_metadata.token_chunk_size
  398. # Compute context length (the number of tokens that are
  399. # already computed) and sequence length (total number of tokens).
  400. seq_len = seq_data.get_len()
  401. if inter_data.is_prompt:
  402. context_len = seq_data.get_num_computed_tokens()
  403. else:
  404. # get_num_computed_tokens is incorrect for spec decoding.
  405. # So, we should have a special logic here.
  406. # TODO: Fix it.
  407. context_len = seq_len - 1
  408. seq_len = min(seq_len, context_len + token_chunk_size)
  409. # Compute tokens.
  410. if inter_data.is_prompt:
  411. tokens = seq_data.get_token_ids()
  412. if context_len != 0 or seq_len < len(tokens):
  413. tokens = tokens[context_len:seq_len]
  414. else:
  415. # Optimization. get_token_ids requires the entire copy of
  416. # tokens.
  417. tokens = seq_data.get_last_token_id()
  418. inter_data.seq_lens[seq_idx] = seq_len
  419. inter_data.orig_seq_lens[seq_idx] = seq_len
  420. inter_data.context_lens[seq_idx] = context_len
  421. if isinstance(tokens, list):
  422. inter_data.input_tokens[seq_idx].extend(tokens)
  423. else:
  424. inter_data.input_tokens[seq_idx].append(tokens)
  425. if (seq_len - context_len) == 1:
  426. inter_data.input_positions[seq_idx].append(seq_len - 1)
  427. else:
  428. inter_data.input_positions[seq_idx].extend(
  429. range(context_len, seq_len))
  430. inter_data.query_lens[
  431. seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
  432. def _compute_for_prefix_cache_hit(
  433. self, inter_data: InterDataForSeqGroup, seq_idx: int,
  434. seq_group_metadata: SequenceGroupMetadata):
  435. """Check if hit prefix cache (i.e., some blocks are already computed).
  436. If hit, update input tokens and positions to only compute the
  437. remaining blocks.
  438. """
  439. computed_block_nums = inter_data.computed_block_nums
  440. # Note that prefix caching does not support sliding window.
  441. prefix_cache_hit = (computed_block_nums is not None
  442. and len(computed_block_nums) > 0
  443. and self.sliding_window is None
  444. and inter_data.is_prompt)
  445. inter_data.prefix_cache_hit = prefix_cache_hit
  446. if not prefix_cache_hit:
  447. return
  448. assert computed_block_nums is not None
  449. # The cache hit prompt tokens in this sequence. Note that
  450. # this may be larger than the sequence length if chunked
  451. # prefill is enabled.
  452. prefix_cache_len = len(computed_block_nums) * self.block_size
  453. # The number of so far computed prompt tokens in this sequence.
  454. context_len = inter_data.context_lens[seq_idx]
  455. # The total number of prompt tokens in this sequence.
  456. # When chunked prefill is enabled, this is the token number of
  457. # computed chunks + current chunk.
  458. seq_len = inter_data.seq_lens[seq_idx]
  459. if prefix_cache_len <= context_len:
  460. # We already passed the cache hit region,
  461. # so do normal computation.
  462. pass
  463. elif context_len < prefix_cache_len < seq_len:
  464. # Partial hit. Compute the missing part.
  465. uncomputed_start = prefix_cache_len - context_len
  466. inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
  467. seq_idx][uncomputed_start:]
  468. inter_data.input_positions[seq_idx] = inter_data.input_positions[
  469. seq_idx][uncomputed_start:]
  470. context_len = prefix_cache_len
  471. inter_data.context_lens[seq_idx] = context_len
  472. inter_data.query_lens[
  473. seq_idx] = inter_data.seq_lens[seq_idx] - context_len
  474. elif seq_len <= prefix_cache_len:
  475. # Full hit. Only compute the last token to avoid
  476. # erroneous behavior. FIXME: Ideally we should directly
  477. # mark all tokens as computed in the scheduler and do not
  478. # schedule this sequence, so this case should not happen.
  479. inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
  480. seq_idx][-1:]
  481. inter_data.input_positions[seq_idx] = inter_data.input_positions[
  482. seq_idx][-1:]
  483. inter_data.query_lens[seq_idx] = 1
  484. inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
  485. def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
  486. seq_idx: int,
  487. seq_group_metadata: SequenceGroupMetadata):
  488. """Update seq_len and curr_sliding_window_block for the given
  489. sequence data (only required by decoding) if sliding window is enabled.
  490. """
  491. curr_sliding_window_block = 0
  492. sliding_seq_len = inter_data.seq_lens[seq_idx]
  493. if not inter_data.is_prompt and self.sliding_window is not None:
  494. # TODO: This is a hack to make sliding window work with
  495. # paged attn. We can remove it if we make paged attn kernel
  496. # to properly handle slinding window attn.
  497. curr_sliding_window_block = self.sliding_window_blocks
  498. if self.scheduler_config.use_v2_block_manager:
  499. # number of elements in last block
  500. suff_len = inter_data.seq_lens[seq_idx] % self.block_size
  501. sliding_seq_len = min(
  502. inter_data.seq_lens[seq_idx],
  503. self.block_aligned_sliding_window + suff_len)
  504. if suff_len > 0:
  505. curr_sliding_window_block += 1
  506. else:
  507. sliding_seq_len = min(inter_data.seq_lens[seq_idx],
  508. self.sliding_window)
  509. inter_data.curr_sliding_window_blocks[
  510. seq_idx] = curr_sliding_window_block
  511. inter_data.seq_lens[seq_idx] = sliding_seq_len
  512. def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
  513. seq_idx: int,
  514. seq_group_metadata: SequenceGroupMetadata):
  515. """If LoRA is enabled, compute LoRA index and prompt mapping."""
  516. if not self.enable_lora:
  517. return
  518. lora_id = seq_group_metadata.lora_int_id
  519. if lora_id > 0:
  520. inter_data.lora_requests.add(seq_group_metadata.lora_request)
  521. query_len = inter_data.query_lens[seq_idx]
  522. inter_data.lora_index_mapping.append([lora_id] * query_len)
  523. sampling_params = seq_group_metadata.sampling_params
  524. if sampling_params and sampling_params.prompt_logprobs is not None:
  525. inter_data.lora_prompt_mapping.append([lora_id] * query_len)
  526. elif not self.chunked_prefill_enabled or seq_group_metadata.do_sample:
  527. inter_data.lora_prompt_mapping.append([lora_id])
  528. else:
  529. inter_data.lora_prompt_mapping.append([])
  530. def _compute_prompt_adapter_input(
  531. self, inter_data: InterDataForSeqGroup,
  532. seq_group_metadata: SequenceGroupMetadata):
  533. """If prompt adapter is enabled, compute index and prompt mapping.
  534. """
  535. # Note that when is_prompt=True, we expect only one sequence
  536. # in the group.
  537. if not self.enable_prompt_adapter:
  538. return
  539. prompt_adapter_id = seq_group_metadata.prompt_adapter_id
  540. if prompt_adapter_id <= 0 or not inter_data.is_prompt:
  541. return
  542. # We expect only one sequence in the group when is_prompt=True.
  543. assert inter_data.n_seqs == 1
  544. query_len = inter_data.query_lens[0]
  545. inter_data.prompt_adapter_request = (
  546. seq_group_metadata.prompt_adapter_request)
  547. num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens
  548. inter_data.prompt_adapter_index_mapping = [
  549. prompt_adapter_id
  550. ] * num_tokens + [0] * (query_len - num_tokens)
  551. inter_data.prompt_adapter_prompt_mapping = [prompt_adapter_id] * (
  552. query_len if seq_group_metadata.sampling_params
  553. and seq_group_metadata.sampling_params.prompt_logprobs else 1)
  554. def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
  555. seq_group_metadata: SequenceGroupMetadata):
  556. """If multi-modal data is given, add it to the input."""
  557. mm_data = seq_group_metadata.multi_modal_data
  558. if not mm_data:
  559. return
  560. mm_kwargs = self.multi_modal_input_mapper(mm_data)
  561. inter_data.multi_modal_inputs = mm_kwargs
  def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
      """Register one sequence group with the builder.

      A reinitialised intermediate-data record is taken from the cache and
      appended to ``self.inter_data_list``. Every per-sequence compute
      function then runs for each sequence in the group (sequence-major
      order), after which every per-group compute function runs once.
      """
      group_seq_ids = seq_group_metadata.seq_data.keys()
      num_group_seqs = len(group_seq_ids)
      prompt_phase = seq_group_metadata.is_prompt

      if prompt_phase:
          # A prompt group holds exactly one sequence, and its presence
          # means the batch is no longer decode-only.
          assert num_group_seqs == 1
          self.decode_only = False

      record = self.init_cached_inter_data(
          request_id=seq_group_metadata.request_id,
          seq_ids=group_seq_ids,
          is_prompt=prompt_phase,
          block_tables=seq_group_metadata.block_tables,
          computed_block_nums=seq_group_metadata.computed_block_nums,
          reinit=True,
          reinit_use_defaults=True,
      )
      self.inter_data_list.append(record)

      # Outer loop over sequences, inner loop over the per-sequence hooks.
      for idx in range(num_group_seqs):
          for compute_fn in self.per_seq_compute_fns:
              compute_fn(record, idx, seq_group_metadata)

      for group_fn in self.per_seq_group_compute_fns:
          group_fn(record, seq_group_metadata)
  584. def _use_captured_graph(self, batch_size: int,
  585. max_decode_seq_len: int) -> bool:
  586. return (self.decode_only and not self.runner.model_config.enforce_eager
  587. and batch_size <= self.runner.max_batchsize_to_capture
  588. and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
  def build(self) -> ModelInputForGPU:
      """Finalize the builder's intermediate data into a model input.

      Flattens the per-group token, position and length lists gathered in
      ``self.inter_data_list`` and copies tokens and positions to the
      runner's device. When ``_use_captured_graph`` approves the batch, the
      following are padded up to ``_get_graph_batch_size(batch_size)``:
        * tokens and positions, with 0;
        * sequence lengths, with 1;
        * the LoRA and prompt-adapter index mappings, with 0.

      Returns:
          A ``self.model_input_cls`` instance, or an empty one when no input
          tokens were collected.
      """
      groups = self.inter_data_list

      flat_tokens = [
          token for group in groups for chunk in group.input_tokens
          for token in chunk
      ]
      if not flat_tokens:
          # This may happen when all prefill requests hit prefix caching
          # and there is no decode request.
          return self.model_input_cls()

      flat_positions = [
          pos for group in groups for chunk in group.input_positions
          for pos in chunk
      ]
      seq_lens = [length for group in groups for length in group.seq_lens]
      max_decode_seq_len = max(
          [0, *(max(group.seq_lens) for group in groups
                if not group.is_prompt)])
      query_lens = [
          length for group in groups for length in group.query_lens
      ]

      # Request ID -> sequence IDs; used by Jamba models, which manage
      # their cache by themselves.
      request_ids_to_seq_ids = {
          group.request_id: group.seq_ids
          for group in groups
      }

      batch_size = len(flat_tokens)
      # -1 is the "no CUDA graph" sentinel handed to the attention metadata
      # builder. Aphrodite uses CUDA graphs only for decoding requests; see
      # `capture_model` for details.
      cuda_graph_pad_size = -1
      if self._use_captured_graph(batch_size, max_decode_seq_len):
          graph_batch_size = _get_graph_batch_size(batch_size)
          assert graph_batch_size >= batch_size
          cuda_graph_pad_size = graph_batch_size - batch_size
          batch_size = graph_batch_size
      needs_padding = cuda_graph_pad_size > 0

      if needs_padding:
          flat_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
          flat_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
          # Padded slots get a sequence length of 1.
          seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))

      runner = self.runner
      assert runner.device is not None
      input_tokens_tensor = async_tensor_h2d(flat_tokens, torch.long,
                                             runner.device,
                                             runner.pin_memory)
      input_positions_tensor = async_tensor_h2d(flat_positions, torch.long,
                                                runner.device,
                                                runner.pin_memory)

      attn_metadata = self.attn_metadata_builder.build(
          seq_lens, query_lens, cuda_graph_pad_size, batch_size)

      # LoRA data.
      lora_requests = set()
      lora_mapping = None
      if self.enable_lora:
          lora_requests = {
              request for group in groups for request in group.lora_requests
          }
          lora_index_mapping = flatten_2d_lists(
              [flatten_2d_lists(group.lora_index_mapping) for group in groups])
          if needs_padding:
              lora_index_mapping.extend(
                  itertools.repeat(0, cuda_graph_pad_size))
          lora_prompt_mapping = flatten_2d_lists(
              [flatten_2d_lists(group.lora_prompt_mapping) for group in groups])
          lora_mapping = LoRAMapping(index_mapping=lora_index_mapping,
                                     prompt_mapping=lora_prompt_mapping,
                                     is_prefill=not self.decode_only)

      # Prompt adapter data.
      prompt_adapter_requests: Set[PromptAdapterRequest] = set()
      prompt_adapter_mapping = None
      if self.enable_prompt_adapter:
          prompt_adapter_requests = {
              group.prompt_adapter_request for group in groups
              if group.prompt_adapter_request is not None
          }
          pa_index_mapping = flatten_2d_lists(
              [group.prompt_adapter_index_mapping for group in groups])
          if needs_padding:
              pa_index_mapping.extend(
                  itertools.repeat(0, cuda_graph_pad_size))
          pa_prompt_mapping = flatten_2d_lists(
              [group.prompt_adapter_prompt_mapping for group in groups])
          prompt_adapter_mapping = PromptAdapterMapping(
              pa_index_mapping,
              pa_prompt_mapping,
          )

      # Multi-modal data.
      multi_modal_kwargs = MultiModalInputs.batch([
          group.multi_modal_inputs for group in groups
          if group.multi_modal_inputs is not None
      ])

      return self.model_input_cls(
          input_tokens=input_tokens_tensor,
          input_positions=input_positions_tensor,
          attn_metadata=attn_metadata,
          seq_lens=seq_lens,
          query_lens=query_lens,
          lora_mapping=lora_mapping,
          lora_requests=lora_requests,
          multi_modal_kwargs=multi_modal_kwargs,
          request_ids_to_seq_ids=request_ids_to_seq_ids,
          finished_requests_ids=self.finished_requests_ids,
          prompt_adapter_mapping=prompt_adapter_mapping,
          prompt_adapter_requests=prompt_adapter_requests)
  class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
      """
      Helper class for shared methods between GPU model runners.

      Holds the engine configs, the attention backend, per-virtual-engine
      CUDA graph runners, and the optional LoRA / prompt-adapter managers.
      Subclasses provide ``_model_input_cls`` and ``_builder_cls``.
      """
      _model_input_cls: Type[TModelInputForGPU]
      _builder_cls: Type[ModelInputForGPUBuilder]

      def __init__(
          self,
          model_config: ModelConfig,
          parallel_config: ParallelConfig,
          scheduler_config: SchedulerConfig,
          device_config: DeviceConfig,
          cache_config: CacheConfig,
          load_config: LoadConfig,
          lora_config: Optional[LoRAConfig],
          kv_cache_dtype: Optional[str] = "auto",
          is_driver_worker: bool = False,
          prompt_adapter_config: Optional[PromptAdapterConfig] = None,
          return_hidden_states: bool = False,
          input_registry: InputRegistry = INPUT_REGISTRY,
          mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
          tp_rank: int = 0,
      ):
          self.model_config = model_config
          self.parallel_config = parallel_config
          self.scheduler_config = scheduler_config
          self.device_config = device_config
          self.cache_config = cache_config
          self.lora_config = lora_config
          self.load_config = load_config
          self.is_driver_worker = is_driver_worker
          self.prompt_adapter_config = prompt_adapter_config
          self.return_hidden_states = return_hidden_states
          self.device = self.device_config.device
          self.pin_memory = is_pin_memory_available()
          self.tp_rank = tp_rank
          self.kv_cache_dtype = kv_cache_dtype
          self.sliding_window = model_config.get_sliding_window()
          self.block_size = cache_config.block_size
          self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
          self.max_batchsize_to_capture = _get_max_graph_batch_size(
              self.scheduler_config.max_num_seqs)
          # One {batch_size: CUDAGraphRunner} dict per pipeline-parallel
          # virtual engine; filled by `capture_model`.
          self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
              {} for _ in range(self.parallel_config.pipeline_parallel_size)
          ]
          self.graph_memory_pool: Optional[Tuple[
              int, int]] = None  # Set during graph capture.
          self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers(
              parallel_config)
          # When using CUDA graph, the input block tables must be padded to
          # max_seq_len_to_capture. However, creating the block table in
          # Python can be expensive. To optimize this, we cache the block
          # table in numpy and only copy the actual input content at every
          # iteration. The shape of the cached block table will be
          # (max batch size to capture, max context len to capture / block size).
          self.graph_block_tables = np.zeros(
              (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
              dtype=np.int32)
          self.attn_backend = get_attn_backend(
              self.model_config.get_head_size(),
              self.model_config.get_sliding_window(),
              self.model_config.dtype,
              self.kv_cache_dtype,
              self.block_size,
              self.model_config.is_attention_free(),
          )
          if self.attn_backend:
              self.attn_state = self.attn_backend.get_state_cls()(
                  weakref.proxy(self))
          else:
              self.attn_state = CommonAttentionState(weakref.proxy(self))

          # Multi-modal data support
          self.input_registry = input_registry
          self.mm_registry = mm_registry
          self.multi_modal_input_mapper = mm_registry \
              .create_input_mapper(model_config)
          self.mm_registry.init_mm_limits_per_prompt(self.model_config)

          # Lazy initialization
          self.model: nn.Module  # Set after load_model
          # Set after load_model.
          self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
          self.prompt_adapter_manager: Optional[
              LRUCacheWorkerPromptAdapterManager] = None

          set_cpu_offload_max_bytes(
              int(self.cache_config.cpu_offload_gb * 1024**3))

          # Used to cache python objects
          self.inter_data_cache: Dict[int, PyObjectCache] = {}
          self.sampling_metadata_cache: SamplingMetadataCache = \
              SamplingMetadataCache()

      def load_model(self) -> None:
          """Load the model weights and apply optional wrappers.

          Steps, each only when configured:
            1. Load the weights, recording ``self.model_memory_usage``
               via ``CudaMemoryProfiler``.
            2. Wrap the model with the LoRA manager, then with the
               prompt-adapter manager.
            3. On ROCm with an FP8 KV cache, load KV-cache scaling factors
               from ``quantization_param_path``.
            4. Compile the model with ``torch.compile`` when the dynamo
               graph-capture test flag is set.

          Raises:
              RuntimeError: FP8 KV-cache scales were provided, but the model
                  has no ``load_kv_cache_scales`` method.
          """
          tp = get_tensor_model_parallel_world_size()
          rank = get_tensor_model_parallel_rank()
          if rank == 0:
              logger.info(f"Loading model {self.model_config.model}...")
          with CudaMemoryProfiler() as m:
              # measure the time it takes to load the model
              start_time = time.time()
              self.model = get_model(model_config=self.model_config,
                                     device_config=self.device_config,
                                     load_config=self.load_config,
                                     lora_config=self.lora_config,
                                     parallel_config=self.parallel_config,
                                     scheduler_config=self.scheduler_config,
                                     cache_config=self.cache_config)
              end_time = time.time()

          self.model_memory_usage = m.consumed_memory
          total_time = end_time - start_time
          if tp > 1:
              if rank == 0:
                  logger.info(f"Model loaded in {total_time:.2f} seconds.")
                  logger.info(
                      "Total model weights memory usage: "
                      f"{self.model_memory_usage * tp / float(2**30):.2f} GiB")
          else:
              logger.info(f"Model weights loaded in {total_time:.2f} seconds.")
              logger.info(
                  "Total model weights memory usage: "
                  f"{self.model_memory_usage / float(2**30):.2f} GiB")

          if self.lora_config:
              assert supports_lora(self.model), "Model does not support LoRA"
              assert not supports_multimodal(
                  self.model
              ), "To be tested: multimodal language model with LoRA settings."

              self.lora_manager = LRUCacheWorkerLoRAManager(
                  self.scheduler_config.max_num_seqs,
                  self.scheduler_config.max_num_batched_tokens,
                  self.vocab_size,
                  self.lora_config,
                  self.device,
                  self.model.embedding_modules,
                  self.model.embedding_padding_modules,
                  max_position_embeddings=self.model.config.
                  max_position_embeddings,
              )
              self.model = self.lora_manager.create_lora_manager(self.model)

          if self.prompt_adapter_config:
              self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
                  self.scheduler_config.max_num_seqs,
                  self.scheduler_config.max_num_batched_tokens, self.device,
                  self.prompt_adapter_config)
              self.model = (
                  self.prompt_adapter_manager.create_prompt_adapter_manager(
                      self.model))

          if self.kv_cache_dtype == "fp8" and is_hip():
              # Currently only ROCm accepts kv-cache scaling factors
              # via quantization_param_path and this will be deprecated
              # in the future.
              if self.model_config.quantization_param_path is not None:
                  if callable(getattr(self.model, "load_kv_cache_scales", None)):
                      warnings.warn(
                          "Loading kv cache scaling factor from JSON is "
                          "deprecated and will be removed. Please include "
                          "kv cache scaling factors in the model checkpoint.",
                          FutureWarning,
                          stacklevel=2)
                      self.model.load_kv_cache_scales(
                          self.model_config.quantization_param_path)
                      # The path must be part of the single message string;
                      # an extra positional argument is not appended to it.
                      logger.info(
                          "Loaded KV cache scaling factors from "
                          f"{self.model_config.quantization_param_path}")
                  else:
                      raise RuntimeError(
                          "Using FP8 KV cache and scaling factors provided but "
                          f"model {self.model.__class__} does not support loading"
                          " scaling factors.", )
              else:
                  logger.warning(
                      "Using FP8 KV cache but no scaling factors "
                      "provided. Defaulting to scaling factors of 1.0. "
                      "This may lead to less accurate results!")

          # `supports_dynamo` must be *called*: a bare function object is
          # always truthy, which would bypass the torch-version gate.
          # NOTE(review): assumes `supports_dynamo` is the zero-argument
          # predicate from the project's utils module; confirm against the
          # import.
          if envs.APHRODITE_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
              logger.info("Compiling the model using torch.compile...")
              start_time = time.time()
              self.model = torch.compile(self.model,
                                         fullgraph=True,
                                         backend="eager")
              end_time = time.time()
              logger.info(
                  f"Model compiled in {end_time - start_time:.2f} seconds.")

      def get_model_memory_usage(self):
          """Return the memory consumed while loading the model weights.

          This is the ``CudaMemoryProfiler`` measurement recorded by
          ``load_model``.
          """
          return self.model_memory_usage

      def save_sharded_state(
          self,
          path: str,
          pattern: Optional[str] = None,
          max_size: Optional[int] = None,
      ) -> None:
          """Save the loaded model to ``path`` using ``ShardedStateLoader``."""
          from aphrodite.modeling.model_loader.loader import ShardedStateLoader
          ShardedStateLoader.save_model(
              self.model,
              path,
              pattern=pattern,
              max_size=max_size,
          )

      def save_tensorized_model(
          self,
          tensorizer_config: TensorizerConfig,
      ) -> None:
          """Save the loaded model using ``TensorizerLoader``."""
          from aphrodite.modeling.model_loader.loader import TensorizerLoader
          TensorizerLoader.save_model(
              self.model,
              tensorizer_config=tensorizer_config,
          )

      def get_max_block_per_batch(self) -> int:
          """Number of KV blocks needed to cover ``max_seq_len_to_capture``.

          Uses ceiling division by the block size.
          """
          block_size = self.block_size
          return (self.max_seq_len_to_capture + block_size - 1) // block_size

      def _prepare_model_input_tensors(
          self,
          seq_group_metadata_list: List[SequenceGroupMetadata],
          finished_requests_ids: Optional[List[str]] = None
      ) -> TModelInputForGPU:
          """Helper method to prepare the model input based on a given
          sequence group. Prepares metadata needed for the base model forward
          pass but not metadata for possible additional steps, e.g., sampling.

          The API assumes seq_group_metadata_list is sorted by
          prefill -> decode.

          The result tensors and data structure also batches input in
          prefill -> decode order. For example,

          - input_tokens[:num_prefill_tokens] contains prefill tokens.
          - input_tokens[num_prefill_tokens:] contains decode tokens.

          If cuda graph is required, this API automatically pads inputs.
          """
          builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
          for seq_group_metadata in seq_group_metadata_list:
              builder.add_seq_group(seq_group_metadata)

          builder.reset_cached_inter_data()

          return builder.build()  # type: ignore

      @torch.inference_mode()
      def profile_run(self) -> None:
          """Run one forward pass on dummy prompts to profile peak memory.

          The dummy batch totals ``max_num_batched_tokens`` tokens. They are
          split across up to ``max_num_seqs`` prompt sequences, or go into a
          single sequence in single-user mode.

          With LoRA enabled, ``max_loras`` dummy adapters are registered and
          assigned round-robin. For multi-modal models the sequence count is
          capped at ``max_num_batched_tokens // max_mm_tokens`` (minimum 1).
          KV caches are passed as ``None``.
          """
          rank = get_tensor_model_parallel_rank()
          if rank == 0:
              logger.info("Profiling peak memory usage...")
          # Enable top-k sampling to reflect the accurate memory usage.
          sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
          max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
          max_num_seqs = self.scheduler_config.max_num_seqs
          single_seq_mode = self.scheduler_config.single_user_mode
          if single_seq_mode and rank == 0:
              logger.info("Running in single sequence profiling mode")

          # This represents the maximum number of different requests that
          # will have unique LoRAs, and therefore the max amount of memory
          # consumption. Register one dummy LoRA per `max_loras` slot, each
          # pointing at a placeholder path, so the warmup reflects that peak.
          dummy_lora_requests: List[LoRARequest] = []
          dummy_lora_requests_per_seq: List[LoRARequest] = []
          if self.lora_config:
              assert self.lora_manager is not None
              with self.lora_manager.dummy_lora_cache():
                  for idx in range(self.lora_config.max_loras):
                      lora_id = idx + 1
                      dummy_lora_request = LoRARequest(
                          lora_name=f"warmup_{lora_id}",
                          lora_int_id=lora_id,
                          lora_local_path="/not/a/real/path",
                      )
                      self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                       rank=LORA_WARMUP_RANK)
                      dummy_lora_requests.append(dummy_lora_request)
                  if single_seq_mode:
                      dummy_lora_requests_per_seq = [dummy_lora_requests[0]]
                  else:
                      dummy_lora_requests_per_seq = [
                          dummy_lora_requests[idx % len(dummy_lora_requests)]
                          for idx in range(max_num_seqs)
                      ]

          # Profile memory usage with max_num_sequences sequences and the
          # total number of tokens equal to max_num_batched_tokens.
          seqs: List[SequenceGroupMetadata] = []
          # Additional GPU memory may be needed for multi-modal encoding,
          # which needs to be accounted for when calculating the GPU blocks
          # for the Aphrodite block manager.
          # To exercise the worst scenario for GPU memory consumption,
          # the number of seqs (batch_size) is chosen to maximize the number
          # of images processed.
          max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
              self.model_config)
          if max_mm_tokens > 0 and not single_seq_mode:
              max_num_seqs_orig = max_num_seqs
              max_num_seqs = min(max_num_seqs,
                                 max_num_batched_tokens // max_mm_tokens)
              if max_num_seqs < 1:
                  expr = (f"min({max_num_seqs_orig}, "
                          f"{max_num_batched_tokens} // {max_mm_tokens})")
                  logger.warning(
                      f"Computed max_num_seqs ({expr}) to be less than 1. "
                      "Setting it to the minimum value of 1.")
                  max_num_seqs = 1

          # Total dummy tokens across all sequences; used to size the
          # intermediate tensors on non-first pipeline ranks.
          batch_size = 0
          if single_seq_mode:
              seq_len = max_num_batched_tokens
              batch_size = seq_len

              seq_data, dummy_multi_modal_data = self.input_registry \
                  .dummy_data_for_profiling(self.model_config,
                                            seq_len,
                                            self.mm_registry)

              seq = SequenceGroupMetadata(
                  request_id="0",
                  is_prompt=True,
                  seq_data={0: seq_data},
                  sampling_params=sampling_params,
                  block_tables=None,
                  lora_request=dummy_lora_requests_per_seq[0]
                  if dummy_lora_requests_per_seq else None,
                  multi_modal_data=dummy_multi_modal_data,
              )
              seqs.append(seq)
          else:
              # Original multi-sequence profiling logic: spread the token
              # budget evenly, giving the remainder to the first groups.
              for group_id in range(max_num_seqs):
                  seq_len = (max_num_batched_tokens // max_num_seqs +
                             (group_id < max_num_batched_tokens % max_num_seqs))
                  batch_size += seq_len

                  seq_data, dummy_multi_modal_data = self.input_registry \
                      .dummy_data_for_profiling(self.model_config,
                                                seq_len,
                                                self.mm_registry)

                  seq = SequenceGroupMetadata(
                      request_id=str(group_id),
                      is_prompt=True,
                      seq_data={group_id: seq_data},
                      sampling_params=sampling_params,
                      block_tables=None,
                      lora_request=dummy_lora_requests_per_seq[group_id]
                      if dummy_lora_requests_per_seq else None,
                      multi_modal_data=dummy_multi_modal_data,
                  )
                  seqs.append(seq)

          # Run the model with the dummy inputs.
          num_layers = self.model_config.get_num_layers(self.parallel_config)
          kv_caches = [None] * num_layers
          finished_requests_ids = [seq.request_id for seq in seqs]
          model_input = self.prepare_model_input(
              seqs, finished_requests_ids=finished_requests_ids)
          intermediate_tensors = None
          if not get_pp_group().is_first_rank:
              intermediate_tensors = self.model.make_empty_intermediate_tensors(
                  batch_size=batch_size,
                  dtype=self.model_config.dtype,
                  device=self.device)
          self.execute_model(model_input, kv_caches, intermediate_tensors)
          torch.cuda.synchronize()

      def _require_lora_manager(self) -> LRUCacheWorkerLoRAManager:
          """Return the LoRA manager, or raise if LoRA is not enabled."""
          if not self.lora_manager:
              raise RuntimeError("LoRA is not enabled.")
          return self.lora_manager

      def _require_prompt_adapter_manager(
              self) -> LRUCacheWorkerPromptAdapterManager:
          """Return the prompt-adapter manager, or raise if it is not enabled."""
          if not self.prompt_adapter_manager:
              raise RuntimeError("PromptAdapter is not enabled.")
          return self.prompt_adapter_manager

      def remove_all_loras(self):
          """Remove every LoRA adapter; raises RuntimeError if LoRA is off."""
          self._require_lora_manager().remove_all_adapters()

      def set_active_loras(self, lora_requests: Set[LoRARequest],
                           lora_mapping: LoRAMapping) -> None:
          """Activate ``lora_requests`` with the given token mapping."""
          self._require_lora_manager().set_active_adapters(
              lora_requests, lora_mapping)

      def add_lora(self, lora_request: LoRARequest) -> bool:
          """Add a LoRA adapter; returns the manager's result."""
          return self._require_lora_manager().add_adapter(lora_request)

      def remove_lora(self, lora_id: int) -> bool:
          """Remove the LoRA adapter ``lora_id``; returns the manager's result."""
          return self._require_lora_manager().remove_adapter(lora_id)

      def pin_lora(self, lora_id: int) -> bool:
          """Pin the LoRA adapter ``lora_id``; returns the manager's result."""
          return self._require_lora_manager().pin_adapter(lora_id)

      def list_loras(self) -> Set[int]:
          """Return the IDs of the LoRA adapters held by the manager."""
          return self._require_lora_manager().list_adapters()

      def remove_all_prompt_adapters(self):
          """Remove every prompt adapter; raises RuntimeError if disabled."""
          self._require_prompt_adapter_manager().remove_all_adapters()

      def set_active_prompt_adapters(
              self, prompt_adapter_requests: Set[PromptAdapterRequest],
              prompt_adapter_mapping: PromptAdapterMapping) -> None:
          """Activate ``prompt_adapter_requests`` with the given mapping."""
          self._require_prompt_adapter_manager().set_active_adapters(
              prompt_adapter_requests, prompt_adapter_mapping)

      def add_prompt_adapter(
              self, prompt_adapter_request: PromptAdapterRequest) -> bool:
          """Add a prompt adapter; returns the manager's result."""
          return self._require_prompt_adapter_manager().add_adapter(
              prompt_adapter_request)

      def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
          """Remove a prompt adapter by ID; returns the manager's result."""
          return self._require_prompt_adapter_manager().remove_adapter(
              prompt_adapter_id)

      def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
          """Pin a prompt adapter by ID; returns the manager's result."""
          return self._require_prompt_adapter_manager().pin_adapter(
              prompt_adapter_id)

      def list_prompt_adapters(self) -> Set[int]:
          """Return the IDs of the prompt adapters held by the manager."""
          return self._require_prompt_adapter_manager().list_adapters()

      @torch.inference_mode()
      def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
          """Cuda graph capture a model.

          Note that CUDA graph's performance gain is negligible if number
          of batched tokens are larger than 200. And since CUDA graph
          requires fixed sized tensors, supporting large/variable batch
          size requires high GPU memory overhead. Thus, Aphrodite only
          captures decoding requests. Mixed batch (chunked prefill + decoding)
          or prefill requests are not captured.

          Since it is used for decoding-only, it assumes there's only 1 token
          per sequence in the batch.
          """
          tp_rank = get_tensor_model_parallel_rank()
          assert not self.model_config.enforce_eager
          if tp_rank == 0:
              logger.info(
                  "Capturing the model for CUDA graphs. This may lead to "
                  "unexpected consequences if the model is not static. To "
                  "run the model in eager mode, set 'enforce_eager=True' or "
                  "use '--enforce-eager' in the CLI.")
              logger.info(
                  "CUDA graphs can take additional 1~3 GiB memory per GPU. "
                  "If you are running out of memory, consider decreasing "
                  "`gpu_memory_utilization` or enforcing eager mode. "
                  "You can also reduce the `max_num_seqs` as needed "
                  "to decrease memory usage.")
          start_time = time.perf_counter()

          # Prepare dummy inputs. These will be reused for all batch sizes.
          max_batch_size = self.max_batchsize_to_capture
          input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
          input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()

          # Prepare dummy previous_hidden_states only if needed by the model.
          # This is used by draft models such as EAGLE.
          previous_hidden_states = None
          if "previous_hidden_states" in inspect.signature(
                  self.model.forward).parameters:
              previous_hidden_states = torch.empty(
                  [max_batch_size,
                   self.model_config.get_hidden_size()],
                  dtype=self.model_config.dtype,
                  device=self.device)

          intermediate_inputs = None
          if not get_pp_group().is_first_rank:
              intermediate_inputs = self.model.make_empty_intermediate_tensors(
                  batch_size=max_batch_size,
                  dtype=self.model_config.dtype,
                  device=self.device)

          # Per-virtual-engine output buffers shared across batch sizes.
          # NOTE(review): nothing in this method assigns into this list, so
          # every capture currently receives None here. Confirm whether
          # `graph_runner.capture`'s return value was meant to be stored.
          hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [
              None
          ] * self.parallel_config.pipeline_parallel_size

          graph_batch_size = self.max_batchsize_to_capture
          batch_size_capture_list = [
              bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
          ]

          with self.attn_state.graph_capture(
                  max_batch_size), graph_capture() as graph_capture_context:
              # NOTE: Capturing the largest batch size first may help reduce
              # the memory usage of CUDA graph.
              for virtual_engine in range(
                      self.parallel_config.pipeline_parallel_size):
                  for batch_size in reversed(batch_size_capture_list):
                      attn_metadata = (
                          self.attn_state.graph_capture_get_metadata_for_batch(
                              batch_size))

                      if self.lora_config:
                          lora_mapping = LoRAMapping(
                              **dict(index_mapping=[0] * batch_size,
                                     prompt_mapping=[0] * batch_size,
                                     is_prefill=False))
                          self.set_active_loras(set(), lora_mapping)

                      if self.prompt_adapter_config:
                          prompt_adapter_mapping = PromptAdapterMapping(
                              [-1] * batch_size,
                              [-1] * batch_size,
                          )
                          self.set_active_prompt_adapters(
                              set(), prompt_adapter_mapping)

                      graph_runner = CUDAGraphRunner(
                          self.model, self.attn_backend.get_name(),
                          self.attn_state.graph_clone(batch_size))

                      output_buffer = hidden_or_intermediate_states[
                          virtual_engine]
                      capture_inputs = {
                          "input_ids":
                          input_tokens[:batch_size],
                          "positions":
                          input_positions[:batch_size],
                          "hidden_or_intermediate_states":
                          output_buffer[:batch_size]  # type: ignore
                          if output_buffer is not None else None,
                          "intermediate_inputs":
                          intermediate_inputs[:batch_size]
                          if intermediate_inputs is not None else None,
                          "kv_caches":
                          kv_caches[virtual_engine],
                          "attn_metadata":
                          attn_metadata,
                          "memory_pool":
                          self.graph_memory_pool,
                          "stream":
                          graph_capture_context.stream
                      }
                      if previous_hidden_states is not None:
                          capture_inputs[
                              "previous_hidden_states"] = previous_hidden_states[:
                                                                                 batch_size]

                      if self.has_seqlen_agnostic:
                          # Only used by Mamba-based models CUDA graph atm (Jamba)
                          capture_inputs.update({
                              "seqlen_agnostic_capture_inputs":
                              self.model.get_seqlen_agnostic_capture_inputs(
                                  batch_size)
                          })
                      graph_runner.capture(**capture_inputs)
                      # Later captures share this memory pool.
                      self.graph_memory_pool = graph_runner.graph.pool()
                      self.graph_runners[virtual_engine][batch_size] = (
                          graph_runner)

          end_time = time.perf_counter()
          elapsed_time = end_time - start_time
          # This usually takes < 10 seconds.
          if tp_rank == 0:
              logger.info(f"Graph capturing finished in {elapsed_time:.2f} secs")

      @property
      def vocab_size(self) -> int:
          """Vocabulary size reported by the model config."""
          return self.model_config.get_vocab_size()
  1234. class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
  1235. """
  1236. GPU model runner with sampling step.
  1237. """
  1238. _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
  1239. ModelInputForGPUWithSamplingMetadata)
  1240. _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
  def make_model_input_from_broadcasted_tensor_dict(
          self,
          tensor_dict: Dict[str, Any],
  ) -> ModelInputForGPUWithSamplingMetadata:
      """Rebuild a model input from a broadcasted tensor dict.

      Delegates to
      ``ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict``
      and supplies this runner's attention backend.
      """
      input_cls = ModelInputForGPUWithSamplingMetadata
      return input_cls.from_broadcasted_tensor_dict(
          tensor_dict,
          attn_backend=self.attn_backend,
      )
  def prepare_model_input(
      self,
      seq_group_metadata_list: List[SequenceGroupMetadata],
      virtual_engine: int = 0,
      finished_requests_ids: Optional[List[str]] = None
  ) -> ModelInputForGPUWithSamplingMetadata:
      """Build the model input for a batch, including sampling metadata.

      ``seq_group_metadata_list`` is expected to be ordered prefill first,
      then decode. The resulting tensors keep that order:
        * ``input_tokens[:num_prefill_tokens]`` are the prefill tokens;
        * ``input_tokens[num_prefill_tokens:]`` are the decode tokens.
      Inputs are padded automatically when a CUDA graph will be used.

      Sampling metadata is only built on the last pipeline-parallel rank;
      elsewhere it is ``None``. ``is_prompt`` is taken from the first
      sequence group, or is ``None`` for an empty list.
      """
      model_input = self._prepare_model_input_tensors(
          seq_group_metadata_list, finished_requests_ids)

      sampling_metadata = None
      if get_pp_group().is_last_rank:
          # Only the final pipeline stage samples, so only it needs this.
          sampling_metadata = SamplingMetadata.prepare(
              seq_group_metadata_list, model_input.seq_lens,
              model_input.query_lens, self.device, self.pin_memory,
              self.get_generators(finished_requests_ids),
              self.sampling_metadata_cache)

      if seq_group_metadata_list:
          is_prompt = seq_group_metadata_list[0].is_prompt
      else:
          is_prompt = None

      return dataclasses.replace(model_input,
                                 sampling_metadata=sampling_metadata,
                                 is_prompt=is_prompt,
                                 virtual_engine=virtual_engine)
  1283. @torch.inference_mode()
  1284. def execute_model(
  1285. self,
  1286. model_input: ModelInputForGPUWithSamplingMetadata,
  1287. kv_caches: List[torch.Tensor],
  1288. intermediate_tensors: Optional[IntermediateTensors] = None,
  1289. num_steps: int = 1,
  1290. ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
  1291. if num_steps > 1:
  1292. raise ValueError("num_steps > 1 is not supported in ModelRunner")
  1293. if self.lora_config:
  1294. assert model_input.lora_requests is not None
  1295. assert model_input.lora_mapping is not None
  1296. self.set_active_loras(model_input.lora_requests,
  1297. model_input.lora_mapping)
  1298. if self.prompt_adapter_config:
  1299. assert model_input.prompt_adapter_requests is not None
  1300. assert model_input.prompt_adapter_mapping is not None
  1301. self.set_active_prompt_adapters(
  1302. model_input.prompt_adapter_requests,
  1303. model_input.prompt_adapter_mapping)
  1304. self.attn_state.begin_forward(model_input)
  1305. # Currently cuda graph is only supported by the decode phase.
  1306. assert model_input.attn_metadata is not None
  1307. prefill_meta = model_input.attn_metadata.prefill_metadata
  1308. decode_meta = model_input.attn_metadata.decode_metadata
  1309. # TODO: We can remove this once all
  1310. # virtual engines share the same kv cache.
  1311. virtual_engine = model_input.virtual_engine
  1312. if prefill_meta is None and decode_meta.use_cuda_graph:
  1313. assert model_input.input_tokens is not None
  1314. graph_batch_size = model_input.input_tokens.shape[0]
  1315. model_executable = self.graph_runners[virtual_engine][
  1316. graph_batch_size]
  1317. else:
  1318. model_executable = self.model
  1319. multi_modal_kwargs = model_input.multi_modal_kwargs or {}
  1320. seqlen_agnostic_kwargs = {
  1321. "finished_requests_ids": model_input.finished_requests_ids,
  1322. "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
  1323. } if self.has_seqlen_agnostic else {}
  1324. hidden_or_intermediate_states = model_executable(
  1325. input_ids=model_input.input_tokens,
  1326. positions=model_input.input_positions,
  1327. kv_caches=kv_caches,
  1328. attn_metadata=model_input.attn_metadata,
  1329. intermediate_tensors=intermediate_tensors,
  1330. **MultiModalInputs.as_kwargs(multi_modal_kwargs,
  1331. device=self.device),
  1332. **seqlen_agnostic_kwargs,
  1333. )
  1334. # Compute the logits in the last pipeline stage.
  1335. if not get_pp_group().is_last_rank:
  1336. return hidden_or_intermediate_states
  1337. logits = self.model.compute_logits(hidden_or_intermediate_states,
  1338. model_input.sampling_metadata)
  1339. if not self.is_driver_worker:
  1340. return []
  1341. if model_input.async_callback is not None:
  1342. model_input.async_callback()
  1343. # Sample the next token.
  1344. output: SamplerOutput = self.model.sample(
  1345. logits=logits,
  1346. sampling_metadata=model_input.sampling_metadata,
  1347. )
  1348. if self.return_hidden_states:
  1349. # we only need to pass hidden states of most recent token
  1350. assert model_input.sampling_metadata is not None
  1351. indices = model_input.sampling_metadata.selected_token_indices
  1352. if model_input.is_prompt:
  1353. hidden_states = hidden_or_intermediate_states.index_select(
  1354. 0, indices)
  1355. output.prefill_hidden_states = hidden_or_intermediate_states
  1356. elif decode_meta.use_cuda_graph:
  1357. hidden_states = hidden_or_intermediate_states[:len(indices)]
  1358. else:
  1359. hidden_states = hidden_or_intermediate_states
  1360. output.hidden_states = hidden_states
  1361. return [output]
  1362. class CUDAGraphRunner:
  1363. def __init__(self, model: nn.Module, backend_name: str,
  1364. attn_state: AttentionState):
  1365. self.model = model
  1366. self.backend_name = backend_name
  1367. self.attn_state = attn_state
  1368. self.input_buffers: Dict[str, torch.Tensor] = {}
  1369. self.output_buffers: Dict[str, torch.Tensor] = {}
  1370. self._graph: Optional[torch.cuda.CUDAGraph] = None
  1371. @property
  1372. def graph(self):
  1373. assert self._graph is not None
  1374. return self._graph
  1375. def capture(
  1376. self,
  1377. input_ids: torch.Tensor,
  1378. positions: torch.Tensor,
  1379. hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
  1380. torch.Tensor]],
  1381. intermediate_inputs: Optional[IntermediateTensors],
  1382. kv_caches: List[torch.Tensor],
  1383. attn_metadata: AttentionMetadata,
  1384. memory_pool: Optional[Tuple[int, int]],
  1385. stream: torch.cuda.Stream,
  1386. **kwargs,
  1387. ) -> Union[torch.Tensor, IntermediateTensors]:
  1388. assert self._graph is None
  1389. # Run the model a few times without capturing the graph.
  1390. # This is to make sure that the captured graph does not include the
  1391. # kernel launches for initial benchmarking (e.g., Triton autotune).
  1392. # Note one iteration is not enough for torch.jit.script
  1393. for _ in range(_NUM_WARMUP_ITERS):
  1394. self.model(
  1395. input_ids=input_ids,
  1396. positions=positions,
  1397. kv_caches=kv_caches,
  1398. attn_metadata=attn_metadata,
  1399. intermediate_tensors=intermediate_inputs,
  1400. **kwargs,
  1401. )
  1402. torch.cuda.synchronize()
  1403. # Capture the graph.
  1404. self._graph = torch.cuda.CUDAGraph()
  1405. with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
  1406. output_hidden_or_intermediate_states = self.model(
  1407. input_ids=input_ids,
  1408. positions=positions,
  1409. kv_caches=kv_caches,
  1410. attn_metadata=attn_metadata,
  1411. intermediate_tensors=intermediate_inputs,
  1412. **kwargs,
  1413. )
  1414. if hidden_or_intermediate_states is not None:
  1415. if get_pp_group().is_last_rank:
  1416. hidden_or_intermediate_states.copy_(
  1417. output_hidden_or_intermediate_states)
  1418. else:
  1419. for key in hidden_or_intermediate_states.tensors:
  1420. hidden_or_intermediate_states[key].copy_(
  1421. output_hidden_or_intermediate_states[key])
  1422. else:
  1423. hidden_or_intermediate_states = (
  1424. output_hidden_or_intermediate_states)
  1425. del output_hidden_or_intermediate_states
  1426. # make sure `output_hidden_states` is deleted
  1427. # in the graph's memory pool
  1428. gc.collect()
  1429. torch.cuda.synchronize()
  1430. # Save the input and output buffers.
  1431. self.input_buffers = {
  1432. "input_ids": input_ids,
  1433. "positions": positions,
  1434. "kv_caches": kv_caches,
  1435. **self.attn_state.get_graph_input_buffers(attn_metadata),
  1436. **kwargs,
  1437. }
  1438. if intermediate_inputs is not None:
  1439. self.input_buffers.update(intermediate_inputs.tensors)
  1440. if get_pp_group().is_last_rank:
  1441. self.output_buffers = {
  1442. "hidden_states": hidden_or_intermediate_states
  1443. }
  1444. else:
  1445. self.output_buffers = hidden_or_intermediate_states
  1446. return hidden_or_intermediate_states
  1447. def forward(
  1448. self,
  1449. input_ids: torch.Tensor,
  1450. positions: torch.Tensor,
  1451. kv_caches: List[torch.Tensor],
  1452. attn_metadata: AttentionMetadata,
  1453. intermediate_tensors: Optional[IntermediateTensors],
  1454. **kwargs,
  1455. ) -> torch.Tensor:
  1456. # KV caches are fixed tensors, so we don't need to copy them.
  1457. del kv_caches
  1458. # Copy the input tensors to the input buffers.
  1459. self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
  1460. self.input_buffers["positions"].copy_(positions, non_blocking=True)
  1461. if self.backend_name != "No attention":
  1462. self.input_buffers["slot_mapping"].copy_(
  1463. attn_metadata.slot_mapping, non_blocking=True)
  1464. self.attn_state.prepare_graph_input_buffers(self.input_buffers,
  1465. attn_metadata)
  1466. if "seqlen_agnostic_capture_inputs" in self.input_buffers:
  1467. self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
  1468. **kwargs)
  1469. if "previous_hidden_states" in self.input_buffers:
  1470. self.input_buffers["previous_hidden_states"].copy_(
  1471. kwargs["previous_hidden_states"], non_blocking=True)
  1472. if intermediate_tensors is not None:
  1473. for key in intermediate_tensors.tensors:
  1474. self.input_buffers[key].copy_(intermediate_tensors[key],
  1475. non_blocking=True)
  1476. # Run the graph.
  1477. self.graph.replay()
  1478. # Return the output tensor.
  1479. if get_pp_group().is_last_rank:
  1480. return self.output_buffers["hidden_states"]
  1481. return self.output_buffers
  1482. def __call__(self, *args, **kwargs):
  1483. return self.forward(*args, **kwargs)
  1484. def _get_graph_batch_size(batch_size: int) -> int:
  1485. """Returns the padded batch size given actual batch size.
  1486. Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
  1487. 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
  1488. """
  1489. if batch_size <= 2:
  1490. return batch_size
  1491. elif batch_size <= 4:
  1492. return 4
  1493. else:
  1494. return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
  1495. _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
  1496. def _get_max_graph_batch_size(max_num_seqs: int) -> int:
  1497. """
  1498. max_num_seqs: Maximum number of sequences in a batch.
  1499. _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
  1500. pad the max_num_seqs if necessary by calling _get_graph_batch_size,
  1501. which will deal with some edge cases like 1, 2, 4.
  1502. if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size.
  1503. if not, it means the padded size is larger than the largest size in
  1504. _BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE.
  1505. """
  1506. padded_size = _get_graph_batch_size(max_num_seqs)
  1507. if padded_size in _BATCH_SIZES_TO_CAPTURE:
  1508. return padded_size
  1509. assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
  1510. return _BATCH_SIZES_TO_CAPTURE[-1]