model_runner.py 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219
  1. import contextlib
  2. import time
  3. from enum import IntEnum
  4. from typing import Dict, List, NamedTuple, Optional, Set, Tuple
  5. import numpy as np
  6. import torch
  7. import torch.nn as nn
  8. from loguru import logger
  9. from aphrodite.attention import (AttentionMetadata, AttentionMetadataPerStage,
  10. get_attn_backend)
  11. from aphrodite.common.config import (DeviceConfig, LoadConfig, LoRAConfig,
  12. ModelConfig, ParallelConfig,
  13. SchedulerConfig, VisionLanguageConfig)
  14. from aphrodite.common.sampling_params import SamplingParams, SamplingType
  15. from aphrodite.common.sequence import (MultiModalData, SamplerOutput,
  16. SequenceData, SequenceGroupMetadata)
  17. from aphrodite.common.utils import (CudaMemoryProfiler, async_tensor_h2d,
  18. is_hip, is_pin_memory_available,
  19. make_tensor_with_pad, maybe_expand_dim)
  20. from aphrodite.distributed import (broadcast_tensor_dict,
  21. get_tensor_model_parallel_world_size,
  22. with_pynccl_for_all_reduce)
  23. from aphrodite.distributed.device_communicators import (custom_all_reduce,
  24. pynccl_utils)
  25. from aphrodite.lora.layers import LoRAMapping
  26. from aphrodite.lora.request import LoRARequest
  27. from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
  28. from aphrodite.modeling import SamplingMetadata
  29. from aphrodite.modeling.model_loader import get_model
  30. from aphrodite.modeling.sampling_metadata import PersistentMetadata
  31. _PAD_SLOT_ID = -1
  32. LORA_WARMUP_RANK = 8
  33. _BATCH_SIZE_ALIGNMENT = 8
  34. # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
  35. # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
  36. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
  37. _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
  38. ]
  39. class PreparePromptMetadata(NamedTuple):
  40. input_tokens: List[int]
  41. input_positions: List[int]
  42. attn_metadata: Optional[AttentionMetadataPerStage]
  43. prompt_lens: List[int]
  44. subquery_lens: List[int]
  45. lora_index_mapping: List[int]
  46. lora_prompt_mapping: List[int]
  47. lora_requests: Set[LoRARequest]
  48. multi_modal_input: Optional[torch.Tensor]
  49. slot_mapping: List[int]
  50. @classmethod
  51. def empty(cls):
  52. return PreparePromptMetadata(
  53. input_tokens=[],
  54. input_positions=[],
  55. attn_metadata=None,
  56. prompt_lens=[],
  57. subquery_lens=[],
  58. lora_index_mapping=[],
  59. lora_prompt_mapping=[],
  60. lora_requests=set(),
  61. multi_modal_input=None,
  62. slot_mapping=[],
  63. )
  64. class PrepareDecodeMetadata(NamedTuple):
  65. input_tokens: List[int]
  66. input_positions: List[int]
  67. attn_metadata: Optional[AttentionMetadata]
  68. lora_index_mapping: List[int]
  69. lora_prompt_mapping: List[int]
  70. lora_requests: Set[LoRARequest]
  71. slot_mapping: List[int]
  72. @classmethod
  73. def empty(cls):
  74. return PrepareDecodeMetadata(
  75. input_tokens=[],
  76. input_positions=[],
  77. attn_metadata=None,
  78. lora_index_mapping=[],
  79. lora_prompt_mapping=[],
  80. lora_requests=set(),
  81. slot_mapping=[],
  82. )
  83. # How batches are constructed.
  84. class BatchType(IntEnum):
  85. # Every batch is prefill.
  86. PREFILL = 0
  87. # Every batch is decode.
  88. DECODE = 1
  89. # Batch is a mixture of prefill and decode.
  90. MIXED = 2
  91. class ModelRunner:
  92. def __init__(
  93. self,
  94. model_config: ModelConfig,
  95. parallel_config: ParallelConfig,
  96. scheduler_config: SchedulerConfig,
  97. device_config: DeviceConfig,
  98. load_config: LoadConfig,
  99. lora_config: Optional[LoRAConfig],
  100. kv_cache_dtype: Optional[str] = "auto",
  101. is_driver_worker: bool = False,
  102. vision_language_config: Optional[VisionLanguageConfig] = None,
  103. ):
  104. self.model_config = model_config
  105. self.parallel_config = parallel_config
  106. self.scheduler_config = scheduler_config
  107. self.lora_config = lora_config
  108. self.load_config = load_config
  109. self.is_driver_worker = is_driver_worker
  110. # model_config can be None in tests/samplers/test_sampler.py.
  111. # FIXME: This is a hack to make the tests work. Refactor this.
  112. self.sliding_window = (model_config.get_sliding_window()
  113. if model_config is not None else None)
  114. self.device_config = (device_config
  115. if device_config is not None else DeviceConfig())
  116. self.device = self.device_config.device
  117. # Set after load_model.
  118. self.lora_manager: LRUCacheWorkerLoRAManager = None
  119. self.graph_runners: Dict[int, CUDAGraphRunner] = {}
  120. self.graph_memory_pool: Optional[Tuple[
  121. int, int]] = None # Set during graph capture.
  122. self.max_context_len_to_capture = (
  123. self.model_config.max_context_len_to_capture
  124. if self.model_config is not None else 0)
  125. self.pin_memory = is_pin_memory_available()
  126. self.kv_cache_dtype = kv_cache_dtype
  127. self.vision_language_config = vision_language_config
  128. self.attn_backend = get_attn_backend(
  129. self.model_config.dtype if model_config is not None else None)
  130. # Lazy initialization
  131. self.model: torch.nn.Module # Set after load_model
  132. self.block_size: int # Set after initial profiling.
  133. # When using CUDA graph, the input block tables must be padded to
  134. # max_context_len_to_capture. However, creating the block table in
  135. # Python can be expensive. To optimize this, we cache the block table
  136. # in numpy and only copy the actual input content at every iteration.
  137. # The shape of the cached block table will be
  138. # (max batch size to capture, max context len to capture / block size).
  139. self.graph_block_tables: torch.Tensor # Set after initial profiling.
  140. def load_model(self) -> None:
  141. with CudaMemoryProfiler() as m:
  142. self.model = get_model(
  143. model_config=self.model_config,
  144. device_config=self.device_config,
  145. load_config=self.load_config,
  146. lora_config=self.lora_config,
  147. vision_language_config=self.vision_language_config,
  148. parallel_config=self.parallel_config,
  149. scheduler_config=self.scheduler_config,
  150. )
  151. self.model_memory_usage = m.consumed_memory
  152. tp = get_tensor_model_parallel_world_size()
  153. logger.info(
  154. "Model weights loaded. Memory usage: "
  155. f"{self.model_memory_usage / float(2**30):.2f} GiB x {tp} = "
  156. f"{self.model_memory_usage * tp / float(2**30):.2f} GiB")
  157. if self.lora_config:
  158. assert hasattr(self.model, "supported_lora_modules"
  159. ) and self.model.supported_lora_modules, (
  160. "Model does not support LoRA")
  161. assert hasattr(
  162. self.model,
  163. "embedding_modules"), "Model does not have embedding_modules"
  164. assert hasattr(self.model, "embedding_padding_modules"
  165. ), "Model does not have embedding_padding_modules"
  166. self.lora_manager = LRUCacheWorkerLoRAManager(
  167. self.scheduler_config.max_num_seqs,
  168. self.scheduler_config.max_num_batched_tokens, self.vocab_size,
  169. self.lora_config, self.device, self.model.embedding_modules,
  170. self.model.embedding_padding_modules)
  171. self.model = self.lora_manager.create_lora_manager(self.model)
  172. if self.kv_cache_dtype == "fp8" and is_hip():
  173. # Currently scaled KV cache is only enabled on ROCm
  174. if self.model_config.quantization_param_path is not None:
  175. if callable(getattr(self.model, "load_kv_cache_scales", None)):
  176. self.model.load_kv_cache_scales(
  177. self.model_config.quantization_param_path)
  178. else:
  179. raise RuntimeError("Using FP8 KV cache and scaling "
  180. "factors provided but model "
  181. f"{self.model.__class__} does not "
  182. "support loading scaling factors.")
  183. else:
  184. logger.warn("Using FP8 KV cache but no scaling factors "
  185. "provided. Defaulting to scaling factors of 1.0. "
  186. "This may lead to less accurate results!")
  187. elif self.model_config.quantization_param_path is not None:
  188. logger.warn("KV cache scaling factors provided, "
  189. "but the KV cache data type is not FP8. "
  190. "KV cache scaling factors will not be used.")
  191. def set_block_size(self, block_size: int) -> None:
  192. self.block_size = block_size
  193. self.graph_block_tables = np.zeros(
  194. (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
  195. dtype=np.int32)
  196. def get_max_block_per_batch(self) -> int:
  197. block_size = self.block_size
  198. return (self.max_context_len_to_capture + block_size - 1) // block_size
  199. def _prepare_prompt(
  200. self,
  201. seq_group_metadata_list: List[SequenceGroupMetadata],
  202. ) -> PreparePromptMetadata:
  203. input_tokens: List[int] = []
  204. input_positions: List[int] = []
  205. slot_mapping: List[int] = []
  206. lora_index_mapping: List[int] = []
  207. lora_prompt_mapping: List[int] = []
  208. lora_requests: Set[LoRARequest] = set()
  209. prompt_lens: List[int] = []
  210. context_lens: List[int] = []
  211. subquery_lens: List[int] = []
  212. prefix_block_tables: List[List[int]] = []
  213. multi_modal_input_list: List[torch.Tensor] = []
  214. if len(seq_group_metadata_list) == 0:
  215. return PreparePromptMetadata.empty()
  216. for seq_group_metadata in seq_group_metadata_list:
  217. assert seq_group_metadata.is_prompt
  218. seq_ids = list(seq_group_metadata.seq_data.keys())
  219. assert len(seq_ids) == 1
  220. seq_id = seq_ids[0]
  221. computed_block_nums = seq_group_metadata.computed_block_nums
  222. if (self.scheduler_config is not None
  223. and self.scheduler_config.chunked_prefill_enabled
  224. and not (computed_block_nums is None
  225. or computed_block_nums == [])):
  226. raise RuntimeError(
  227. "chunked prefill cannot be used with prefix caching "
  228. "now.")
  229. token_chunk_size = seq_group_metadata.token_chunk_size
  230. seq_data = seq_group_metadata.seq_data[seq_id]
  231. computed_len = seq_data.get_num_computed_tokens()
  232. # We should use get_len here because in case of preemption
  233. # it contains output tokens.
  234. prefill_end = min(seq_data.get_len(),
  235. computed_len + token_chunk_size)
  236. prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
  237. prompt_len = prefill_end
  238. prompt_lens.append(prompt_len)
  239. # NOTE: This only works for oooooooxxx style attention.
  240. if computed_block_nums is not None and len(
  241. computed_block_nums) > 0 and self.sliding_window is None:
  242. # Prefix is not supported with sliding_window
  243. computed_len = len(computed_block_nums) * self.block_size
  244. prompt_tokens = prompt_tokens[computed_len:]
  245. prefix_block_tables.append(computed_block_nums)
  246. elif self.scheduler_config.chunked_prefill_enabled:
  247. if seq_group_metadata.block_tables is not None:
  248. # Prefill has chunked before.
  249. block_table = seq_group_metadata.block_tables[seq_id]
  250. prefix_block_tables.append(block_table)
  251. else:
  252. # The first prefill.
  253. prefix_block_tables.append([])
  254. else:
  255. prefix_block_tables.append([])
  256. # Right now, prefill start is always 0. However, this
  257. # assumption can be changed once chunked prefill is introduced.
  258. assert computed_len == 0
  259. # actual prompt lens
  260. context_lens.append(computed_len)
  261. subquery_lens.append(prompt_len - computed_len)
  262. input_tokens.extend(prompt_tokens)
  263. # NOTE: Here we assume that the first token in the prompt
  264. # is always the first token in the sequence.
  265. input_positions.extend(list(range(computed_len, prefill_end)))
  266. lora_id = seq_group_metadata.lora_int_id
  267. if lora_id > 0:
  268. lora_requests.add(seq_group_metadata.lora_request)
  269. lora_index_mapping += [lora_id] * (prompt_len - computed_len)
  270. lora_prompt_mapping.extend(
  271. [lora_id] *
  272. (prompt_len - computed_len
  273. if seq_group_metadata.sampling_params.prompt_logprobs else 1))
  274. if seq_group_metadata.multi_modal_data:
  275. multi_modal_input_list.append(
  276. seq_group_metadata.multi_modal_data.data)
  277. if seq_group_metadata.block_tables is None:
  278. # During memory profiling, the block tables are not initialized
  279. # yet. In this case, we just use a dummy slot mapping.
  280. slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
  281. continue
  282. # Compute the slot mapping.
  283. block_table = seq_group_metadata.block_tables[seq_id]
  284. # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
  285. # where start_idx is max(0, prompt_len - sliding_window).
  286. # For example, if the prompt len is 10, sliding window is 8, and
  287. # block size is 4, the first two tokens are masked and the slot
  288. # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
  289. start_idx = 0
  290. if self.sliding_window is not None:
  291. assert computed_len == 0, (
  292. "Prefix caching is currently not supported with "
  293. "sliding window attention")
  294. start_idx = max(0, prompt_len - self.sliding_window)
  295. for i in range(computed_len, prefill_end):
  296. if i < start_idx:
  297. slot_mapping.append(_PAD_SLOT_ID)
  298. continue
  299. block_number = block_table[i // self.block_size]
  300. block_offset = i % self.block_size
  301. slot = block_number * self.block_size + block_offset
  302. slot_mapping.append(slot)
  303. max_subquery_len = max(subquery_lens)
  304. max_prompt_len = max(prompt_lens)
  305. assert max_subquery_len > 0
  306. context_lens_tensor = torch.tensor(context_lens,
  307. dtype=torch.int,
  308. device=self.device)
  309. if multi_modal_input_list:
  310. assert self.vision_language_config, (
  311. "Multi-modal inputs are only supported by "
  312. "vision language models.")
  313. multi_modal_input = torch.cat(multi_modal_input_list,
  314. dim=0).to(self.device)
  315. else:
  316. multi_modal_input = None
  317. # Prepare prefix block tables
  318. max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
  319. block_tables = make_tensor_with_pad(
  320. prefix_block_tables,
  321. max_len=max_prompt_block_table_len,
  322. pad=0,
  323. dtype=torch.int,
  324. device=self.device,
  325. )
  326. # Query length can be shorter than key (i.e., prompt) when prefill
  327. # is chunked or prefix cached.
  328. subquery_lens_tensor = torch.tensor(subquery_lens,
  329. dtype=torch.long,
  330. device=self.device)
  331. subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
  332. dtype=torch.int32,
  333. device=self.device)
  334. prompt_lens_tensor = torch.tensor(prompt_lens,
  335. dtype=torch.long,
  336. device=self.device)
  337. seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
  338. dtype=torch.int32,
  339. device=self.device)
  340. torch.cumsum(subquery_lens_tensor,
  341. dim=0,
  342. dtype=subquery_start_loc.dtype,
  343. out=subquery_start_loc[1:])
  344. torch.cumsum(prompt_lens_tensor,
  345. dim=0,
  346. dtype=seq_start_loc.dtype,
  347. out=seq_start_loc[1:])
  348. attn_metadata = self.attn_backend.make_metadata(
  349. is_prompt=True,
  350. prompt_lens=prompt_lens,
  351. prompt_lens_tensor=prompt_lens_tensor,
  352. max_subquery_len=max_subquery_len,
  353. max_context_len=None,
  354. max_prompt_len=max_prompt_len,
  355. subquery_start_loc=subquery_start_loc,
  356. seq_start_loc=seq_start_loc,
  357. context_lens=context_lens_tensor,
  358. block_tables=block_tables,
  359. use_cuda_graph=False,
  360. )
  361. return PreparePromptMetadata(
  362. input_tokens=input_tokens,
  363. input_positions=input_positions,
  364. attn_metadata=attn_metadata,
  365. prompt_lens=prompt_lens,
  366. subquery_lens=subquery_lens,
  367. lora_index_mapping=lora_index_mapping,
  368. lora_prompt_mapping=lora_prompt_mapping,
  369. lora_requests=lora_requests,
  370. multi_modal_input=multi_modal_input,
  371. slot_mapping=slot_mapping,
  372. )
  373. def _prepare_decode(
  374. self,
  375. seq_group_metadata_list: List[SequenceGroupMetadata],
  376. ) -> PrepareDecodeMetadata:
  377. input_tokens: List[int] = []
  378. input_positions: List[int] = []
  379. slot_mapping: List[int] = []
  380. context_lens: List[int] = []
  381. block_tables: List[List[int]] = []
  382. lora_index_mapping: List[int] = []
  383. lora_prompt_mapping: List[int] = []
  384. lora_requests: Set[LoRARequest] = set()
  385. if len(seq_group_metadata_list) == 0:
  386. return PrepareDecodeMetadata.empty()
  387. for seq_group_metadata in seq_group_metadata_list:
  388. assert not seq_group_metadata.is_prompt
  389. assert seq_group_metadata.token_chunk_size == 1
  390. seq_ids = list(seq_group_metadata.seq_data.keys())
  391. lora_id = seq_group_metadata.lora_int_id
  392. if lora_id > 0:
  393. lora_requests.add(seq_group_metadata.lora_request)
  394. for seq_id in seq_ids:
  395. seq_data = seq_group_metadata.seq_data[seq_id]
  396. generation_token = seq_data.get_last_token_id()
  397. input_tokens.append(generation_token)
  398. seq_len = seq_data.get_len()
  399. position = seq_len - 1
  400. input_positions.append(position)
  401. context_len = seq_len if self.sliding_window is None else min(
  402. seq_len, self.sliding_window)
  403. context_lens.append(context_len)
  404. block_table = seq_group_metadata.block_tables[seq_id]
  405. block_number = block_table[position // self.block_size]
  406. block_offset = position % self.block_size
  407. slot = block_number * self.block_size + block_offset
  408. slot_mapping.append(slot)
  409. lora_index_mapping.append(lora_id)
  410. lora_prompt_mapping.append(lora_id)
  411. if self.sliding_window is not None:
  412. sliding_window_blocks = (self.sliding_window //
  413. self.block_size)
  414. block_table = block_table[-sliding_window_blocks:]
  415. block_tables.append(block_table)
  416. # Aphrodite uses cuda graph only for decoding requests.
  417. # See `capture_model` API for more details.
  418. # For decoding requests, batch_size == input_tokens.
  419. batch_size = len(input_tokens)
  420. max_context_len = max(context_lens)
  421. use_captured_graph = (
  422. not self.model_config.enforce_eager
  423. and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
  424. and max_context_len <= self.max_context_len_to_capture)
  425. if use_captured_graph:
  426. graph_batch_size = _get_graph_batch_size(batch_size)
  427. assert graph_batch_size >= batch_size
  428. for _ in range(graph_batch_size - batch_size):
  429. input_tokens.append(0)
  430. input_positions.append(0)
  431. slot_mapping.append(_PAD_SLOT_ID)
  432. context_lens.append(1)
  433. block_tables.append([])
  434. lora_index_mapping.append(0)
  435. batch_size = graph_batch_size
  436. context_lens_tensor = torch.tensor(context_lens,
  437. dtype=torch.int,
  438. device=self.device)
  439. if use_captured_graph:
  440. # When using cuda-graph all these tensors should be
  441. # padded.
  442. assert context_lens_tensor.shape[0] == len(input_tokens)
  443. assert context_lens_tensor.shape[0] == len(input_positions)
  444. assert context_lens_tensor.shape[0] == len(slot_mapping)
  445. # The shape of graph_block_tables is
  446. # [max batch size, max context len // block size].
  447. input_block_tables = self.graph_block_tables[:batch_size]
  448. for i, block_table in enumerate(block_tables):
  449. if block_table:
  450. input_block_tables[i, :len(block_table)] = block_table
  451. block_tables = torch.tensor(input_block_tables, device=self.device)
  452. else:
  453. max_block_table_len = max(
  454. len(block_table) for block_table in block_tables)
  455. block_tables = make_tensor_with_pad(
  456. block_tables,
  457. max_len=max_block_table_len,
  458. pad=0,
  459. dtype=torch.int,
  460. device=self.device,
  461. )
  462. attn_metadata = self.attn_backend.make_metadata(
  463. is_prompt=False,
  464. prompt_lens=None,
  465. prompt_lens_tensor=None,
  466. max_subquery_len=None,
  467. max_context_len=max_context_len,
  468. max_prompt_len=None,
  469. subquery_start_loc=None,
  470. seq_start_loc=None,
  471. context_lens=context_lens_tensor,
  472. block_tables=block_tables,
  473. use_cuda_graph=use_captured_graph,
  474. )
  475. return PrepareDecodeMetadata(
  476. input_tokens=input_tokens,
  477. input_positions=input_positions,
  478. attn_metadata=attn_metadata,
  479. lora_index_mapping=lora_index_mapping,
  480. lora_prompt_mapping=lora_prompt_mapping,
  481. lora_requests=lora_requests,
  482. slot_mapping=slot_mapping,
  483. )
  484. def _prepare_sample(
  485. self,
  486. seq_group_metadata_list: List[SequenceGroupMetadata],
  487. prompt_lens: List[int],
  488. subquery_lens: Optional[List[int]],
  489. ) -> SamplingMetadata:
  490. seq_groups: List[Tuple[List[int], SamplingParams]] = []
  491. selected_token_indices: List[int] = []
  492. generators: List[torch.Generator] = []
  493. selected_token_start_idx = 0
  494. categorized_sample_indices: Dict[SamplingType,
  495. List[Tuple[int, int]]] = {
  496. t: []
  497. for t in SamplingType
  498. }
  499. categorized_sample_indices_start_idx = 0
  500. categorized_sampled_token_indices_start_idx = 0
  501. for i, seq_group_metadata in enumerate(seq_group_metadata_list):
  502. seq_ids = list(seq_group_metadata.seq_data.keys())
  503. sampling_params = seq_group_metadata.sampling_params
  504. seq_groups.append((seq_ids, sampling_params))
  505. if seq_group_metadata.is_prompt:
  506. assert len(seq_ids) == 1
  507. assert subquery_lens is not None
  508. subquery_len = subquery_lens[i]
  509. if sampling_params.prompt_logprobs is not None:
  510. # NOTE: prompt token positions do not need sample, skip
  511. categorized_sample_indices_start_idx += subquery_len - 1
  512. categorized_sample_indices[
  513. sampling_params.sampling_type].append(
  514. (categorized_sample_indices_start_idx,
  515. categorized_sampled_token_indices_start_idx))
  516. categorized_sample_indices_start_idx += 1
  517. categorized_sampled_token_indices_start_idx += 1
  518. if sampling_params.prompt_logprobs is not None:
  519. selected_token_indices.extend(
  520. range(selected_token_start_idx,
  521. selected_token_start_idx + subquery_len - 1))
  522. selected_token_indices.append(selected_token_start_idx +
  523. subquery_len - 1)
  524. selected_token_start_idx += subquery_len
  525. if sampling_params.seed is not None:
  526. seq_group_metadata.state.generator = torch.Generator(
  527. device=self.device).manual_seed(sampling_params.seed)
  528. else:
  529. num_seqs = len(seq_ids)
  530. selected_token_indices.extend(
  531. range(selected_token_start_idx,
  532. selected_token_start_idx + num_seqs))
  533. selected_token_start_idx += num_seqs
  534. categorized_sample_indices[
  535. sampling_params.sampling_type].extend(
  536. list(
  537. zip(
  538. range(
  539. categorized_sample_indices_start_idx,
  540. categorized_sample_indices_start_idx +
  541. num_seqs),
  542. range(
  543. categorized_sampled_token_indices_start_idx,
  544. categorized_sampled_token_indices_start_idx
  545. + num_seqs))))
  546. categorized_sample_indices_start_idx += num_seqs
  547. categorized_sampled_token_indices_start_idx += num_seqs
  548. if sampling_params.seed is not None:
  549. generators.append(seq_group_metadata.state.generator)
  550. selected_token_indices = async_tensor_h2d(selected_token_indices,
  551. dtype=torch.long,
  552. target_device=self.device,
  553. pin_memory=self.pin_memory)
  554. categorized_sample_indices = {
  555. t: maybe_expand_dim(
  556. async_tensor_h2d(seq_ids,
  557. dtype=torch.int,
  558. target_device=self.device,
  559. pin_memory=self.pin_memory), 2, 2)
  560. for t, seq_ids in categorized_sample_indices.items()
  561. }
  562. seq_data: Dict[int, SequenceData] = {}
  563. for seq_group_metadata in seq_group_metadata_list:
  564. seq_data.update(seq_group_metadata.seq_data)
  565. seq_persistence_data: Dict[int, dict] = {}
  566. for grp in seq_group_metadata_list:
  567. seq_persistence_data.update(grp.persistent_data)
  568. sampling_metadata = SamplingMetadata(
  569. seq_groups=seq_groups,
  570. seq_data=seq_data,
  571. prompt_lens=prompt_lens,
  572. selected_token_indices=selected_token_indices,
  573. categorized_sample_indices=categorized_sample_indices,
  574. generators=generators,
  575. persistent_metadata=PersistentMetadata(seq_persistence_data),
  576. )
  577. return sampling_metadata
  578. def prepare_input_tensors(
  579. self,
  580. seq_group_metadata_list: List[SequenceGroupMetadata],
  581. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
  582. Set[LoRARequest], LoRAMapping, torch.Tensor]:
  583. if self.is_driver_worker:
  584. prefill_reqs = []
  585. decode_reqs = []
  586. for seq_group_meta in seq_group_metadata_list:
  587. if seq_group_meta.is_prompt:
  588. prefill_reqs.append(seq_group_meta)
  589. else:
  590. decode_reqs.append(seq_group_meta)
  591. # Prepare input tensors.
  592. (
  593. input_tokens,
  594. input_positions,
  595. prefill_attn_metadata,
  596. prompt_lens,
  597. subquery_lens,
  598. lora_index_mapping,
  599. lora_prompt_mapping,
  600. lora_requests,
  601. multi_modal_input,
  602. slot_mapping,
  603. ) = self._prepare_prompt(prefill_reqs)
  604. (
  605. decode_input_tokens,
  606. decode_input_positions,
  607. decode_attn_metadata,
  608. decode_lora_index_mapping,
  609. decode_lora_prompt_mapping,
  610. decode_lora_requests,
  611. decode_slot_mapping,
  612. ) = self._prepare_decode(decode_reqs)
  613. sampling_metadata = self._prepare_sample(seq_group_metadata_list,
  614. prompt_lens,
  615. subquery_lens)
  616. if not self.scheduler_config.chunked_prefill_enabled:
  617. assert (len(prefill_reqs) and len(decode_reqs)) == 0
  618. num_prefills = len(prompt_lens)
  619. num_prefill_tokens = len(input_tokens)
  620. num_decode_tokens = len(decode_input_tokens)
  621. # Coalesce tensors. Note that attn_metadata is currently not
  622. # coalesced for simplicity.
  623. input_tokens.extend(decode_input_tokens)
  624. input_positions.extend(decode_input_positions)
  625. slot_mapping.extend(decode_slot_mapping)
  626. lora_index_mapping.extend(decode_lora_index_mapping)
  627. lora_prompt_mapping.extend(decode_lora_prompt_mapping)
  628. lora_requests.update(decode_lora_requests)
  629. input_tokens = torch.tensor(input_tokens,
  630. dtype=torch.long,
  631. device=self.device)
  632. input_positions = torch.tensor(input_positions,
  633. dtype=torch.long,
  634. device=self.device)
  635. slot_mapping = torch.tensor(slot_mapping,
  636. dtype=torch.long,
  637. device=self.device)
  638. if self.lora_config:
  639. lora_mapping = LoRAMapping(
  640. lora_index_mapping,
  641. lora_prompt_mapping,
  642. )
  643. else:
  644. lora_mapping = None
  645. # Broadcast the metadata.
  646. # If batch contains both prefill and decode, it sends 2 broadcasts.
  647. # If it only contains 1 type, it triggers a single broadcast.
  648. if (prefill_attn_metadata is not None
  649. and decode_attn_metadata is not None):
  650. batch_type = BatchType.MIXED
  651. elif prefill_attn_metadata is not None:
  652. batch_type = BatchType.PREFILL
  653. else:
  654. batch_type = BatchType.DECODE
  655. metadata_dict = {
  656. "input_tokens": input_tokens,
  657. "input_positions": input_positions,
  658. "selected_token_indices":
  659. sampling_metadata.selected_token_indices,
  660. "lora_requests": lora_requests,
  661. "lora_mapping": lora_mapping,
  662. "multi_modal_input": multi_modal_input,
  663. "num_prefill_tokens": num_prefill_tokens,
  664. "num_decode_tokens": num_decode_tokens,
  665. "slot_mapping": slot_mapping,
  666. "num_prefills": num_prefills,
  667. "batch_type": batch_type,
  668. }
  669. if prefill_attn_metadata is not None:
  670. metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
  671. else:
  672. assert decode_attn_metadata is not None
  673. metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
  674. broadcast_tensor_dict(metadata_dict, src=0)
  675. # Broadcast decode attn metadata for mixed batch type.
  676. # The additional broadcast costs 300us overhead on 4 A10 GPUs.
  677. # We can potentially reduce the overhead by coelescing tensors.
  678. if batch_type == BatchType.MIXED:
  679. assert decode_attn_metadata is not None
  680. metadata_dict = decode_attn_metadata.asdict_zerocopy()
  681. broadcast_tensor_dict(metadata_dict, src=0)
  682. else:
  683. metadata_dict = broadcast_tensor_dict(src=0)
  684. input_tokens = metadata_dict.pop("input_tokens")
  685. input_positions = metadata_dict.pop("input_positions")
  686. slot_mapping = metadata_dict.pop("slot_mapping")
  687. num_prefills = metadata_dict.pop("num_prefills")
  688. selected_token_indices = metadata_dict.pop(
  689. "selected_token_indices")
  690. lora_mapping = metadata_dict.pop("lora_mapping")
  691. lora_requests = metadata_dict.pop("lora_requests")
  692. multi_modal_input = metadata_dict.pop("multi_modal_input")
  693. num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
  694. num_decode_tokens = metadata_dict.pop("num_decode_tokens")
  695. batch_type = metadata_dict.pop("batch_type")
  696. # Create an attention metadata.
  697. prefill_attn_metadata = None
  698. decode_attn_metadata = None
  699. if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
  700. prefill_attn_metadata = self.attn_backend.make_metadata(
  701. **metadata_dict)
  702. else:
  703. decode_attn_metadata = self.attn_backend.make_metadata(
  704. **metadata_dict)
  705. sampling_metadata = SamplingMetadata(
  706. seq_groups=None,
  707. seq_data=None,
  708. prompt_lens=None,
  709. selected_token_indices=selected_token_indices,
  710. categorized_sample_indices=None,
  711. generators=None,
  712. perform_sampling=False,
  713. )
  714. # if it is a mixed batch, decode attn_metadata is broadcasted
  715. # separately.
  716. if batch_type == BatchType.MIXED:
  717. metadata_dict = broadcast_tensor_dict(src=0)
  718. decode_attn_metadata = self.attn_backend.make_metadata(
  719. **metadata_dict)
  720. attn_metadata = AttentionMetadata(
  721. num_prefills=num_prefills,
  722. slot_mapping=slot_mapping,
  723. num_prefill_tokens=num_prefill_tokens,
  724. num_decode_tokens=num_decode_tokens,
  725. prefill_metadata=prefill_attn_metadata,
  726. decode_metadata=decode_attn_metadata,
  727. kv_cache_dtype=self.kv_cache_dtype,
  728. )
  729. return (input_tokens, input_positions, attn_metadata,
  730. sampling_metadata, lora_requests, lora_mapping,
  731. multi_modal_input)
  @torch.inference_mode()
  def execute_model(
      self,
      seq_group_metadata_list: List[SequenceGroupMetadata],
      kv_caches: List[torch.Tensor],
  ) -> Optional[SamplerOutput]:
      """Run the model on one scheduled batch and sample when requested.

      Args:
          seq_group_metadata_list: Sequence groups scheduled for this step;
              turned into model inputs by ``prepare_input_tensors``.
          kv_caches: KV cache tensors forwarded to the model unchanged.

      Returns:
          The result of ``self.model.sample(...)`` when
          ``sampling_metadata.perform_sampling`` is true, otherwise ``None``
          (e.g. workers that receive broadcast inputs build their sampling
          metadata with ``perform_sampling=False``).
      """
      prepared = self.prepare_input_tensors(seq_group_metadata_list)
      (input_tokens, input_positions, attn_metadata, sampling_metadata,
       lora_requests, lora_mapping, multi_modal_input) = prepared

      if self.lora_config:
          self.set_active_loras(lora_requests, lora_mapping)

      # Currently cuda graph is only supported by the decode phase: replay a
      # captured graph only for batches with no prefill part whose decode
      # metadata requests it; everything else runs the eager model.
      prefill_meta = attn_metadata.prefill_metadata
      decode_meta = attn_metadata.decode_metadata
      use_graph = prefill_meta is None and decode_meta.use_cuda_graph
      if use_graph:
          # Graph runners are keyed by the (padded) number of input tokens.
          model_executable = self.graph_runners[input_tokens.shape[0]]
      else:
          model_executable = self.model

      model_kwargs = {
          "input_ids": input_tokens,
          "positions": input_positions,
          "kv_caches": kv_caches,
          "attn_metadata": attn_metadata,
      }
      if self.vision_language_config:
          # NOTE(review): CUDAGraphRunner.forward ignores extra kwargs, so
          # image_input only reaches the model on the eager path.
          model_kwargs["image_input"] = multi_modal_input
      hidden_states = model_executable(**model_kwargs)

      # Compute the logits.
      logits = self.model.compute_logits(hidden_states, sampling_metadata)

      # Only perform sampling in the driver worker.
      if sampling_metadata.perform_sampling:
          return self.model.sample(
              logits=logits,
              sampling_metadata=sampling_metadata,
          )
      return None
  @torch.inference_mode()
  def profile_run(self) -> None:
      """Run one worst-case dummy prefill batch to expose peak memory usage.

      Builds up to ``scheduler_config.max_num_seqs`` prompt-only sequence
      groups whose lengths sum to ``scheduler_config.max_num_batched_tokens``.
      When LoRA is enabled, ``lora_config.max_loras`` dummy adapters are
      registered and assigned round-robin to those groups. The batch is run
      through ``execute_model`` with a ``None`` KV cache entry per layer,
      followed by ``torch.cuda.synchronize()``.
      """
      # Enable top-k sampling to reflect the accurate memory usage.
      sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
      max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
      max_num_seqs = self.scheduler_config.max_num_seqs

      # Worst case for LoRA memory: every one of the `max_loras` slots holds a
      # different adapter. Register that many dummy adapters (rank
      # LORA_WARMUP_RANK, placeholder path) and hand them out round-robin to
      # the dummy sequences built below.
      dummy_lora_requests = []
      dummy_lora_requests_per_seq = []
      if self.lora_config:
          for idx in range(self.lora_config.max_loras):
              lora_id = idx + 1
              dummy_lora_request = LoRARequest(
                  lora_name=f"warmup_{lora_id}",
                  lora_int_id=lora_id,
                  lora_local_path="/not/a/real/path",
              )
              self.lora_manager.add_dummy_lora(dummy_lora_request,
                                               rank=LORA_WARMUP_RANK)
              dummy_lora_requests.append(dummy_lora_request)
          dummy_lora_requests_per_seq = [
              dummy_lora_requests[idx % len(dummy_lora_requests)]
              for idx in range(max_num_seqs)
          ]

      # Profile memory usage with max_num_seqs sequences and the total
      # number of tokens equal to max_num_batched_tokens.
      seqs: List[SequenceGroupMetadata] = []

      # Vision encoding may need additional GPU memory, which has to be
      # accounted for before the GPU blocks for Aphrodite's block manager are
      # sized. To exercise the worst case, choose the batch size that
      # maximizes the number of images processed: each fake prompt carries
      # image_feature_size image tokens, so at most
      # max_num_batched_tokens // image_feature_size sequences fit. Exact
      # integer division avoids a float round-trip.
      if self.vision_language_config:
          max_num_seqs = min(
              max_num_seqs,
              max_num_batched_tokens //
              self.vision_language_config.image_feature_size)
      # NOTE(review): if max_num_batched_tokens < image_feature_size,
      # max_num_seqs becomes 0 and no dummy sequences are built.

      for group_id in range(max_num_seqs):
          # Split the token budget as evenly as possible: the first
          # (max_num_batched_tokens % max_num_seqs) groups get one extra token.
          seq_len = (max_num_batched_tokens // max_num_seqs +
                     (group_id < max_num_batched_tokens % max_num_seqs))
          seq_data, fake_multi_modal_input = _prepare_fake_inputs(
              seq_len, self.vision_language_config)
          seq = SequenceGroupMetadata(
              request_id=str(group_id),
              is_prompt=True,
              seq_data={group_id: seq_data},
              sampling_params=sampling_params,
              block_tables=None,
              persistent_data={},
              lora_request=dummy_lora_requests_per_seq[group_id]
              if dummy_lora_requests_per_seq else None,
              multi_modal_data=fake_multi_modal_input,
          )
          seqs.append(seq)

      # Run the model with the dummy inputs; every layer's KV cache entry is
      # None for this run.
      num_layers = self.model_config.get_num_layers(self.parallel_config)
      kv_caches = [None] * num_layers
      self.execute_model(seqs, kv_caches)
      torch.cuda.synchronize()
  def remove_all_loras(self) -> bool:
      """Delegate to ``self.lora_manager.remove_all_loras()``.

      Returns:
          Whatever the LoRA manager returns.

      Raises:
          RuntimeError: If this runner has no LoRA manager.
      """
      manager = self.lora_manager
      if manager:
          return manager.remove_all_loras()
      raise RuntimeError("LoRA is not enabled.")
  def set_active_loras(self, lora_requests: Set[LoRARequest],
                       lora_mapping: LoRAMapping) -> None:
      """Delegate to ``self.lora_manager.set_active_loras``.

      Args:
          lora_requests: LoRA requests to make active.
          lora_mapping: Mapping passed through to the LoRA manager unchanged.

      Raises:
          RuntimeError: If this runner has no LoRA manager.
      """
      manager = self.lora_manager
      if manager:
          manager.set_active_loras(lora_requests, lora_mapping)
          return
      raise RuntimeError("LoRA is not enabled.")
  def add_lora(self, lora_request: LoRARequest) -> bool:
      """Delegate to ``self.lora_manager.add_lora(lora_request)``.

      Returns:
          Whatever the LoRA manager returns.

      Raises:
          RuntimeError: If this runner has no LoRA manager.
      """
      manager = self.lora_manager
      if manager:
          return manager.add_lora(lora_request)
      raise RuntimeError("LoRA is not enabled.")
  def remove_lora(self, lora_id: int) -> bool:
      """Delegate to ``self.lora_manager.remove_lora(lora_id)``.

      Returns:
          Whatever the LoRA manager returns.

      Raises:
          RuntimeError: If this runner has no LoRA manager.
      """
      manager = self.lora_manager
      if manager:
          return manager.remove_lora(lora_id)
      raise RuntimeError("LoRA is not enabled.")
  def list_loras(self) -> Set[int]:
      """Delegate to ``self.lora_manager.list_loras()``.

      Returns:
          Whatever the LoRA manager returns.

      Raises:
          RuntimeError: If this runner has no LoRA manager.
      """
      manager = self.lora_manager
      if manager:
          return manager.list_loras()
      raise RuntimeError("LoRA is not enabled.")
  @torch.inference_mode()
  def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
      """Cuda graph capture a model.

      Note that CUDA graph's performance gain is negligible if number
      of batched tokens are larger than 200. And since CUDA graph
      requires fixed sized tensors, supporting large/variable batch
      size requires high GPU memory overhead. Thus, Aphrodite only captures
      decoding requests. Mixed batch (chunked prefill + decoding) or
      prefill requests are not captured.

      Since it is used for decoding-only, it assumes there's only 1 token
      per sequence in the batch.

      One ``CUDAGraphRunner`` is captured for every size in
      ``_BATCH_SIZES_TO_CAPTURE`` that does not exceed
      ``_get_graph_batch_size(scheduler_config.max_num_seqs)``, largest size
      first. Each runner is stored in ``self.graph_runners[batch_size]``, and
      all of them share ``self.graph_memory_pool``.

      Args:
          kv_caches: KV cache tensors passed to the model while capturing.
      """
      # NOTE: This is a hack to ensure that the NCCL backend is never
      # deleted before the CUDA graphs.
      self.pynccl_backend = pynccl_utils.get_nccl_backend()

      assert not self.model_config.enforce_eager
      logger.info("Capturing the model for CUDA graphs. This may lead to "
                  "unexpected consequences if the model is not static. To "
                  "run the model in eager mode, set 'enforce_eager=True' or "
                  "use '--enforce-eager' in the CLI.")
      logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                  "If you are running out of memory, consider decreasing "
                  "`gpu_memory_utilization` or enforcing eager mode. "
                  "You can also reduce the `max_num_seqs` as needed "
                  "to decrease memory usage.")
      start_time = time.perf_counter()

      # Prepare dummy inputs. These will be reused for all batch sizes.
      # They are created directly on the current CUDA device rather than
      # built on the host and copied over with `.cuda()`.
      max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
      input_tokens = torch.zeros(max_batch_size,
                                 dtype=torch.long,
                                 device="cuda")
      input_positions = torch.zeros(max_batch_size,
                                    dtype=torch.long,
                                    device="cuda")
      # Every dummy slot is set to the padding slot id.
      slot_mapping = torch.full((max_batch_size, ),
                                _PAD_SLOT_ID,
                                dtype=torch.long,
                                device="cuda")
      context_lens = torch.ones(max_batch_size,
                                dtype=torch.int32,
                                device="cuda")
      block_tables = torch.from_numpy(self.graph_block_tables).cuda()

      graph_batch_size = _get_graph_batch_size(
          self.scheduler_config.max_num_seqs)
      batch_size_capture_list = [
          bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
      ]

      # NOTE: There are 3 backends for all-reduce: custom all-reduce
      # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
      # either custom all-reduce kernel or pynccl. When not using CUDA
      # graph, we use either custom all-reduce kernel or PyTorch NCCL.
      # We always prioritize using custom all-reduce kernel but fall back
      # to PyTorch or pynccl if it is disabled or not supported.
      with custom_all_reduce.capture():
          # NOTE: Capturing the largest batch size first may help reduce the
          # memory usage of CUDA graph.
          for batch_size in reversed(batch_size_capture_list):
              # Create dummy attn_metadata.
              decode_metadata = self.attn_backend.make_metadata(
                  is_prompt=False,
                  prompt_lens=None,
                  prompt_lens_tensor=None,
                  max_subquery_len=None,
                  max_context_len=self.max_context_len_to_capture,
                  max_prompt_len=None,
                  subquery_start_loc=None,
                  seq_start_loc=None,
                  context_lens=context_lens[:batch_size],
                  block_tables=block_tables[:batch_size],
                  use_cuda_graph=True,
              )
              attn_metadata = AttentionMetadata(
                  num_prefills=0,
                  num_prefill_tokens=0,
                  num_decode_tokens=batch_size,
                  slot_mapping=slot_mapping[:batch_size],
                  prefill_metadata=None,
                  decode_metadata=decode_metadata,
                  kv_cache_dtype=self.kv_cache_dtype,
              )
              if self.lora_config:
                  # Capture with no active LoRA requests and all-zero
                  # index/prompt mappings.
                  lora_mapping = LoRAMapping(
                      [0] * batch_size,
                      [0] * batch_size,
                  )
                  self.set_active_loras(set(), lora_mapping)
              graph_runner = CUDAGraphRunner(self.model)
              graph_runner.capture(
                  input_tokens[:batch_size],
                  input_positions[:batch_size],
                  kv_caches,
                  attn_metadata,
                  memory_pool=self.graph_memory_pool,
              )
              # The next (smaller) graph is captured into this graph's pool.
              self.graph_memory_pool = graph_runner.graph.pool()
              self.graph_runners[batch_size] = graph_runner

      end_time = time.perf_counter()
      elapsed_time = end_time - start_time
      # This usually takes < 10 seconds.
      logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
  def __del__(self) -> None:
      """Drop the CUDA graph runners, then the pynccl backend reference."""
      # Delete the CUDA graphs before deleting the pynccl communicator.
      # NOTE: This is necessary because otherwise deadlocks can
      # happen.
      # FIXME: This is a bit hacky. Find a more robust solution.
      # TODO: when we get enough user feedback that pynccl is
      # more stable than cupy, we can remove this
      #
      # __del__ can also run on an object whose __init__ never got as far as
      # assigning graph_runners; skip the clear() in that case instead of
      # raising AttributeError during garbage collection.
      graph_runners = getattr(self, "graph_runners", None)
      if graph_runners is not None:
          graph_runners.clear()
      self.pynccl_backend = None
  @property
  def vocab_size(self) -> int:
      """Vocabulary size as reported by ``model_config.get_vocab_size()``."""
      config = self.model_config
      return config.get_vocab_size()
  class CUDAGraphRunner:
      """Record one forward pass of a model as a CUDA graph and replay it.

      ``capture`` runs the model once eagerly, then records a second run into
      a ``torch.cuda.CUDAGraph``, keeping the tensors it was given as the
      graph's input buffers and the result as its output buffer. ``forward``
      copies new values into those input buffers and replays the graph.
      """

      def __init__(self, model: nn.Module):
          self.model = model
          # Filled in by capture(): the tensors the recorded graph reads from
          # and writes to.
          self.input_buffers: Dict[str, torch.Tensor] = {}
          self.output_buffers: Dict[str, torch.Tensor] = {}
          self._graph: Optional[torch.cuda.CUDAGraph] = None

      @property
      def graph(self):
          """The captured graph; asserts that ``capture`` has already run."""
          assert self._graph is not None
          return self._graph

      def _run_model(self, input_ids, positions, kv_caches, attn_metadata,
                     **kwargs):
          """Call the wrapped model once inside ``_maybe_pynccl()``."""
          with _maybe_pynccl():
              return self.model(
                  input_ids,
                  positions,
                  kv_caches,
                  attn_metadata,
                  **kwargs,
              )

      def capture(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
          memory_pool,
          **kwargs,
      ) -> None:
          """Record the model's forward pass into a new CUDA graph.

          Args:
              input_ids: Token ids; kept as the graph's input buffer.
              positions: Positions; kept as the graph's input buffer.
              kv_caches: Passed to the model and kept as-is.
              attn_metadata: Its ``slot_mapping`` and the decode metadata's
                  ``context_lens``/``block_tables`` are kept as input buffers.
              memory_pool: Passed as ``pool=`` to ``torch.cuda.graph``.
              **kwargs: Extra keyword arguments forwarded to the model.
          """
          assert self._graph is None

          # Run the model once without capturing the graph so the recorded
          # graph does not contain one-time kernel launches (e.g. Triton
          # autotune benchmarking).
          self._run_model(input_ids, positions, kv_caches, attn_metadata,
                          **kwargs)
          torch.cuda.synchronize()

          self._graph = torch.cuda.CUDAGraph()
          with torch.cuda.graph(self._graph, pool=memory_pool):
              hidden_states = self._run_model(input_ids, positions,
                                              kv_caches, attn_metadata,
                                              **kwargs)
          torch.cuda.synchronize()

          decode_metadata = attn_metadata.decode_metadata
          self.input_buffers = {
              "input_ids": input_ids,
              "positions": positions,
              "kv_caches": kv_caches,
              "slot_mapping": attn_metadata.slot_mapping,
              "context_lens": decode_metadata.context_lens,
              "block_tables": decode_metadata.block_tables,
          }
          self.output_buffers = {"hidden_states": hidden_states}

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
          **kwargs,
      ) -> torch.Tensor:
          """Copy new inputs into the captured buffers and replay the graph.

          ``kv_caches`` and any extra ``kwargs`` are not used.

          Returns:
              The captured ``hidden_states`` buffer, rewritten by the replay.
          """
          # KV caches are fixed tensors, so we don't need to copy them.
          del kv_caches
          decode_metadata = attn_metadata.decode_metadata
          fresh_inputs = (
              ("input_ids", input_ids),
              ("positions", positions),
              ("slot_mapping", attn_metadata.slot_mapping),
              ("context_lens", decode_metadata.context_lens),
              ("block_tables", decode_metadata.block_tables),
          )
          for buffer_name, source in fresh_inputs:
              self.input_buffers[buffer_name].copy_(source, non_blocking=True)

          self.graph.replay()
          return self.output_buffers["hidden_states"]

      def __call__(self, *args, **kwargs):
          return self.forward(*args, **kwargs)
  @contextlib.contextmanager
  def _maybe_pynccl():
      """Enter ``with_pynccl_for_all_reduce()`` only when pynccl should be used.

      That is the case when pynccl is initialized and the custom all-reduce
      kernel is not; otherwise the body runs under a no-op context.
      """
      use_pynccl = (pynccl_utils.is_initialized()
                    and not custom_all_reduce.is_initialized())
      guard = (with_pynccl_for_all_reduce()
               if use_pynccl else contextlib.nullcontext())
      with guard:
          yield
  def _get_graph_batch_size(batch_size: int) -> int:
      """Return the padded batch size for an actual batch size.

      Padded sizes are 1, 2, 4, then multiples of ``_BATCH_SIZE_ALIGNMENT``
      (``_BATCH_SIZE_ALIGNMENT``, ``2 * _BATCH_SIZE_ALIGNMENT``, ...). Sizes of
      2 or less are returned unchanged; 3 and 4 map to 4.
      """
      if batch_size > 4:
          # Round up to the next multiple of the alignment (ceil division).
          num_chunks = -(-batch_size // _BATCH_SIZE_ALIGNMENT)
          return num_chunks * _BATCH_SIZE_ALIGNMENT
      if batch_size > 2:
          return 4
      return batch_size
  def _prepare_fake_inputs(
          seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
      """Prepare fake inputs for profile run.

      Returns:
          A ``(SequenceData, multi-modal input)`` pair.

          - Without a vision-language config, the prompt is ``seq_len`` zero
            tokens and the multi-modal input is ``None``.
          - With one, the prompt is ``image_feature_size`` copies of
            ``image_token_id`` followed by zeros up to ``seq_len``. Its length
            is therefore ``max(seq_len, image_feature_size)``. The image input
            is an all-zero float16 tensor of shape ``image_input_shape``.
      """
      if not vision_language_config:
          return SequenceData([0] * seq_len), None

      num_image_tokens = vision_language_config.image_feature_size
      image_tokens = [vision_language_config.image_token_id] * num_image_tokens
      # A negative repeat count yields [], so no padding when seq_len is short.
      padding = [0] * (seq_len - num_image_tokens)
      prompt_tokens = image_tokens + padding
      fake_image_input = MultiModalData(
          type=MultiModalData.Type.IMAGE,
          data=torch.zeros(vision_language_config.image_input_shape,
                           dtype=torch.float16))
      return SequenceData(prompt_tokens), fake_image_input