model_runner.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226
  1. import contextlib
  2. import time
  3. from enum import IntEnum
  4. from typing import Dict, List, NamedTuple, Optional, Set, Tuple
  5. import numpy as np
  6. import torch
  7. import torch.nn as nn
  8. from loguru import logger
  9. from aphrodite.attention import (
  10. AttentionMetadata,
  11. AttentionMetadataPerStage,
  12. get_attn_backend,
  13. )
  14. from aphrodite.common.config import (
  15. DeviceConfig,
  16. LoRAConfig,
  17. ModelConfig,
  18. ParallelConfig,
  19. SchedulerConfig,
  20. VisionLanguageConfig,
  21. )
  22. from aphrodite.common.logger import get_loading_progress_bar
  23. from aphrodite.common.sampling_params import SamplingParams, SamplingType
  24. from aphrodite.common.sequence import (
  25. MultiModalData,
  26. SamplerOutput,
  27. SequenceData,
  28. SequenceGroupMetadata,
  29. )
  30. from aphrodite.common.utils import (
  31. CudaMemoryProfiler,
  32. async_tensor_h2d,
  33. is_hip,
  34. is_pin_memory_available,
  35. make_tensor_with_pad,
  36. maybe_expand_dim,
  37. )
  38. from aphrodite.distributed import (
  39. broadcast_tensor_dict,
  40. get_tensor_model_parallel_world_size,
  41. with_pynccl_for_all_reduce,
  42. )
  43. from aphrodite.distributed.device_communicators import (
  44. custom_all_reduce,
  45. pynccl_utils,
  46. )
  47. from aphrodite.lora.layers import LoRAMapping
  48. from aphrodite.lora.request import LoRARequest
  49. from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
  50. from aphrodite.modeling import SamplingMetadata
  51. from aphrodite.modeling.loader import get_model
  52. from aphrodite.modeling.sampling_metadata import PersistentMetadata
  53. _PAD_SLOT_ID = -1
  54. LORA_WARMUP_RANK = 8
  55. _BATCH_SIZE_ALIGNMENT = 8
  56. # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
  57. # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
  58. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
  59. _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
  60. ]
  61. class PreparePromptMetadata(NamedTuple):
  62. input_tokens: List[int]
  63. input_positions: List[int]
  64. attn_metadata: Optional[AttentionMetadataPerStage]
  65. prompt_lens: List[int]
  66. subquery_lens: List[int]
  67. lora_index_mapping: List[int]
  68. lora_prompt_mapping: List[int]
  69. lora_requests: Set[LoRARequest]
  70. multi_modal_input: Optional[torch.Tensor]
  71. slot_mapping: List[int]
  72. @classmethod
  73. def empty(cls):
  74. return PreparePromptMetadata(
  75. input_tokens=[],
  76. input_positions=[],
  77. attn_metadata=None,
  78. prompt_lens=[],
  79. subquery_lens=[],
  80. lora_index_mapping=[],
  81. lora_prompt_mapping=[],
  82. lora_requests=set(),
  83. multi_modal_input=None,
  84. slot_mapping=[],
  85. )
  86. class PrepareDecodeMetadata(NamedTuple):
  87. input_tokens: List[int]
  88. input_positions: List[int]
  89. attn_metadata: Optional[AttentionMetadata]
  90. lora_index_mapping: List[int]
  91. lora_prompt_mapping: List[int]
  92. lora_requests: Set[LoRARequest]
  93. slot_mapping: List[int]
  94. @classmethod
  95. def empty(cls):
  96. return PrepareDecodeMetadata(
  97. input_tokens=[],
  98. input_positions=[],
  99. attn_metadata=None,
  100. lora_index_mapping=[],
  101. lora_prompt_mapping=[],
  102. lora_requests=set(),
  103. slot_mapping=[],
  104. )
  105. # How batches are constructed.
  106. class BatchType(IntEnum):
  107. # Every batch is prefill.
  108. PREFILL = 0
  109. # Every batch is decode.
  110. DECODE = 1
  111. # Batch is a mixture of prefill and decode.
  112. MIXED = 2
  113. class ModelRunner:
  114. def __init__(
  115. self,
  116. model_config: ModelConfig,
  117. parallel_config: ParallelConfig,
  118. scheduler_config: SchedulerConfig,
  119. device_config: DeviceConfig,
  120. lora_config: Optional[LoRAConfig],
  121. kv_cache_dtype: Optional[str] = "auto",
  122. is_driver_worker: bool = False,
  123. vision_language_config: Optional[VisionLanguageConfig] = None,
  124. ):
  125. self.model_config = model_config
  126. self.parallel_config = parallel_config
  127. self.scheduler_config = scheduler_config
  128. self.lora_config = lora_config
  129. self.is_driver_worker = is_driver_worker
  130. # model_config can be None in tests/samplers/test_sampler.py.
  131. # FIXME: This is a hack to make the tests work. Refactor this.
  132. self.sliding_window = (model_config.get_sliding_window()
  133. if model_config is not None else None)
  134. self.device_config = (device_config
  135. if device_config is not None else DeviceConfig())
  136. self.device = self.device_config.device
  137. self.model = None
  138. self.block_size = None # Set after initial profiling.
  139. self.lora_manager = None
  140. self.graph_runners: Dict[int, CUDAGraphRunner] = {}
  141. self.graph_memory_pool = None # Set during graph capture.
  142. self.max_context_len_to_capture = (
  143. self.model_config.max_context_len_to_capture
  144. if self.model_config is not None else 0)
  145. # When using CUDA graph, the input block tables must be padded to
  146. # max_context_len_to_capture. However, creating the block table in
  147. # Python can be expensive. To optimize this, we cache the block table
  148. # in numpy and only copy the actual input content at every iteration.
  149. # The shape of the cached block table will be
  150. # (max batch size to capture, max context len to capture / block size).
  151. self.graph_block_tables = None # Set after initial profiling.
  152. self.pin_memory = is_pin_memory_available()
  153. self.kv_cache_dtype = kv_cache_dtype
  154. self.vision_language_config = vision_language_config
  155. self.attn_backend = get_attn_backend(
  156. self.model_config.dtype if model_config is not None else None)
  157. def load_model(self) -> None:
  158. with CudaMemoryProfiler() as m:
  159. self.model = get_model(
  160. self.model_config,
  161. self.device_config,
  162. lora_config=self.lora_config,
  163. vision_language_config=self.vision_language_config,
  164. parallel_config=self.parallel_config,
  165. scheduler_config=self.scheduler_config)
  166. self.model_memory_usage = m.consumed_memory
  167. tp = get_tensor_model_parallel_world_size()
  168. logger.info(
  169. "Model weights loaded. Memory usage: "
  170. f"{self.model_memory_usage / float(2**30):.2f} GiB x {tp} = "
  171. f"{self.model_memory_usage * tp / float(2**30):.2f} GiB")
  172. if self.lora_config:
  173. assert hasattr(self.model, "supported_lora_modules"
  174. ) and self.model.supported_lora_modules, (
  175. "Model does not support LoRA")
  176. assert hasattr(
  177. self.model,
  178. "embedding_modules"), "Model does not have embedding_modules"
  179. assert hasattr(self.model, "embedding_padding_modules"
  180. ), "Model does not have embedding_padding_modules"
  181. self.lora_manager = LRUCacheWorkerLoRAManager(
  182. self.scheduler_config.max_num_seqs,
  183. self.scheduler_config.max_num_batched_tokens, self.vocab_size,
  184. self.lora_config, self.device, self.model.embedding_modules,
  185. self.model.embedding_padding_modules)
  186. self.model = self.lora_manager.create_lora_manager(self.model)
  187. if self.kv_cache_dtype == "fp8" and is_hip():
  188. # Currently scaled KV cache is only enabled on ROCm
  189. if self.model_config.quantization_param_path is not None:
  190. if callable(getattr(self.model, "load_kv_cache_scales", None)):
  191. self.model.load_kv_cache_scales(
  192. self.model_config.quantization_param_path)
  193. else:
  194. raise RuntimeError("Using FP8 KV cache and scaling "
  195. "factors provided but model "
  196. f"{self.model.__class__} does not "
  197. "support loading scaling factors.")
  198. else:
  199. logger.warning(
  200. "Using FP8 KV cache but no scaling factors "
  201. "provided. Defaulting to scaling factors of 1.0. "
  202. "This may lead to less accurate results!")
  203. elif self.model_config.quantization_param_path is not None:
  204. logger.warning("KV cache scaling factors provided, "
  205. "but the KV cache data type is not FP8. "
  206. "KV cache scaling factors will not be used.")
  207. def set_block_size(self, block_size: int) -> None:
  208. self.block_size = block_size
  209. self.graph_block_tables = np.zeros(
  210. (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
  211. dtype=np.int32)
  212. def get_max_block_per_batch(self) -> int:
  213. block_size = self.block_size
  214. return (self.max_context_len_to_capture + block_size - 1) // block_size
  215. def _prepare_prompt(
  216. self,
  217. seq_group_metadata_list: List[SequenceGroupMetadata],
  218. ) -> PreparePromptMetadata:
  219. input_tokens: List[int] = []
  220. input_positions: List[int] = []
  221. slot_mapping: List[int] = []
  222. lora_index_mapping: List[int] = []
  223. lora_prompt_mapping: List[int] = []
  224. lora_requests: Set[LoRARequest] = set()
  225. prompt_lens: List[int] = []
  226. context_lens: List[int] = []
  227. subquery_lens: List[int] = []
  228. prefix_block_tables: List[List[int]] = []
  229. multi_modal_input_list: List[torch.Tensor] = []
  230. if len(seq_group_metadata_list) == 0:
  231. return PreparePromptMetadata.empty()
  232. for seq_group_metadata in seq_group_metadata_list:
  233. assert seq_group_metadata.is_prompt
  234. seq_ids = list(seq_group_metadata.seq_data.keys())
  235. assert len(seq_ids) == 1
  236. seq_id = seq_ids[0]
  237. computed_block_nums = seq_group_metadata.computed_block_nums
  238. if (self.scheduler_config is not None
  239. and self.scheduler_config.chunked_prefill_enabled
  240. and not (computed_block_nums is None
  241. or computed_block_nums == [])):
  242. raise RuntimeError(
  243. "chunked prefill cannot be used with prefix caching "
  244. "now.")
  245. token_chunk_size = seq_group_metadata.token_chunk_size
  246. seq_data = seq_group_metadata.seq_data[seq_id]
  247. computed_len = seq_data.get_num_computed_tokens()
  248. # We should use get_len here because in case of preemption
  249. # it contains output tokens.
  250. prefill_end = min(seq_data.get_len(),
  251. computed_len + token_chunk_size)
  252. prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
  253. prompt_len = prefill_end
  254. prompt_lens.append(prompt_len)
  255. # NOTE: This only works for oooooooxxx style attention.
  256. if computed_block_nums is not None and len(
  257. computed_block_nums) > 0 and self.sliding_window is None:
  258. # Prefix is not supported with sliding_window
  259. computed_len = len(computed_block_nums) * self.block_size
  260. prompt_tokens = prompt_tokens[computed_len:]
  261. prefix_block_tables.append(computed_block_nums)
  262. elif self.scheduler_config.chunked_prefill_enabled:
  263. if seq_group_metadata.block_tables is not None:
  264. # Prefill has chunked before.
  265. block_table = seq_group_metadata.block_tables[seq_id]
  266. prefix_block_tables.append(block_table)
  267. else:
  268. # The first prefill.
  269. prefix_block_tables.append([])
  270. else:
  271. prefix_block_tables.append([])
  272. # Right now, prefill start is always 0. However, this
  273. # assumption can be changed once chunked prefill is introduced.
  274. assert computed_len == 0
  275. # actual prompt lens
  276. context_lens.append(computed_len)
  277. subquery_lens.append(prompt_len - computed_len)
  278. input_tokens.extend(prompt_tokens)
  279. # NOTE: Here we assume that the first token in the prompt
  280. # is always the first token in the sequence.
  281. input_positions.extend(list(range(computed_len, prefill_end)))
  282. lora_id = seq_group_metadata.lora_int_id
  283. if lora_id > 0:
  284. lora_requests.add(seq_group_metadata.lora_request)
  285. lora_index_mapping += [lora_id] * (prompt_len - computed_len)
  286. lora_prompt_mapping.extend(
  287. [lora_id] *
  288. (prompt_len - computed_len
  289. if seq_group_metadata.sampling_params.prompt_logprobs else 1))
  290. if seq_group_metadata.multi_modal_data:
  291. multi_modal_input_list.append(
  292. seq_group_metadata.multi_modal_data.data)
  293. if seq_group_metadata.block_tables is None:
  294. # During memory profiling, the block tables are not initialized
  295. # yet. In this case, we just use a dummy slot mapping.
  296. slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
  297. continue
  298. # Compute the slot mapping.
  299. block_table = seq_group_metadata.block_tables[seq_id]
  300. # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
  301. # where start_idx is max(0, prompt_len - sliding_window).
  302. # For example, if the prompt len is 10, sliding window is 8, and
  303. # block size is 4, the first two tokens are masked and the slot
  304. # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
  305. start_idx = 0
  306. if self.sliding_window is not None:
  307. assert computed_len == 0, (
  308. "Prefix caching is currently not supported with "
  309. "sliding window attention")
  310. start_idx = max(0, prompt_len - self.sliding_window)
  311. for i in range(computed_len, prefill_end):
  312. if i < start_idx:
  313. slot_mapping.append(_PAD_SLOT_ID)
  314. continue
  315. block_number = block_table[i // self.block_size]
  316. block_offset = i % self.block_size
  317. slot = block_number * self.block_size + block_offset
  318. slot_mapping.append(slot)
  319. max_subquery_len = max(subquery_lens)
  320. max_prompt_len = max(prompt_lens)
  321. assert max_subquery_len > 0
  322. context_lens_tensor = torch.tensor(context_lens,
  323. dtype=torch.int,
  324. device=self.device)
  325. if multi_modal_input_list:
  326. assert self.vision_language_config, (
  327. "Multi-modal inputs are only supported by "
  328. "vision language models.")
  329. multi_modal_input = torch.cat(multi_modal_input_list,
  330. dim=0).to(self.device)
  331. else:
  332. multi_modal_input = None
  333. # Prepare prefix block tables
  334. max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
  335. block_tables = make_tensor_with_pad(
  336. prefix_block_tables,
  337. max_len=max_prompt_block_table_len,
  338. pad=0,
  339. dtype=torch.int,
  340. device=self.device,
  341. )
  342. # Query length can be shorter than key (i.e., prompt) when prefill
  343. # is chunked or prefix cached.
  344. subquery_lens_tensor = torch.tensor(subquery_lens,
  345. dtype=torch.long,
  346. device=self.device)
  347. subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
  348. dtype=torch.int32,
  349. device=self.device)
  350. prompt_lens_tensor = torch.tensor(prompt_lens,
  351. dtype=torch.long,
  352. device=self.device)
  353. seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
  354. dtype=torch.int32,
  355. device=self.device)
  356. torch.cumsum(subquery_lens_tensor,
  357. dim=0,
  358. dtype=subquery_start_loc.dtype,
  359. out=subquery_start_loc[1:])
  360. torch.cumsum(prompt_lens_tensor,
  361. dim=0,
  362. dtype=seq_start_loc.dtype,
  363. out=seq_start_loc[1:])
  364. attn_metadata = self.attn_backend.make_metadata(
  365. is_prompt=True,
  366. prompt_lens=prompt_lens,
  367. prompt_lens_tensor=prompt_lens_tensor,
  368. max_subquery_len=max_subquery_len,
  369. max_context_len=None,
  370. max_prompt_len=max_prompt_len,
  371. subquery_start_loc=subquery_start_loc,
  372. seq_start_loc=seq_start_loc,
  373. context_lens=context_lens_tensor,
  374. block_tables=block_tables,
  375. use_cuda_graph=False,
  376. )
  377. return PreparePromptMetadata(
  378. input_tokens=input_tokens,
  379. input_positions=input_positions,
  380. attn_metadata=attn_metadata,
  381. prompt_lens=prompt_lens,
  382. subquery_lens=subquery_lens,
  383. lora_index_mapping=lora_index_mapping,
  384. lora_prompt_mapping=lora_prompt_mapping,
  385. lora_requests=lora_requests,
  386. multi_modal_input=multi_modal_input,
  387. slot_mapping=slot_mapping,
  388. )
  389. def _prepare_decode(
  390. self,
  391. seq_group_metadata_list: List[SequenceGroupMetadata],
  392. ) -> PrepareDecodeMetadata:
  393. input_tokens: List[int] = []
  394. input_positions: List[int] = []
  395. slot_mapping: List[int] = []
  396. context_lens: List[int] = []
  397. block_tables: List[List[int]] = []
  398. lora_index_mapping: List[int] = []
  399. lora_prompt_mapping: List[int] = []
  400. lora_requests: Set[LoRARequest] = set()
  401. if len(seq_group_metadata_list) == 0:
  402. return PrepareDecodeMetadata.empty()
  403. for seq_group_metadata in seq_group_metadata_list:
  404. assert not seq_group_metadata.is_prompt
  405. assert seq_group_metadata.token_chunk_size == 1
  406. seq_ids = list(seq_group_metadata.seq_data.keys())
  407. lora_id = seq_group_metadata.lora_int_id
  408. if lora_id > 0:
  409. lora_requests.add(seq_group_metadata.lora_request)
  410. for seq_id in seq_ids:
  411. seq_data = seq_group_metadata.seq_data[seq_id]
  412. generation_token = seq_data.get_last_token_id()
  413. input_tokens.append(generation_token)
  414. seq_len = seq_data.get_len()
  415. position = seq_len - 1
  416. input_positions.append(position)
  417. context_len = seq_len if self.sliding_window is None else min(
  418. seq_len, self.sliding_window)
  419. context_lens.append(context_len)
  420. block_table = seq_group_metadata.block_tables[seq_id]
  421. block_number = block_table[position // self.block_size]
  422. block_offset = position % self.block_size
  423. slot = block_number * self.block_size + block_offset
  424. slot_mapping.append(slot)
  425. lora_index_mapping.append(lora_id)
  426. lora_prompt_mapping.append(lora_id)
  427. if self.sliding_window is not None:
  428. sliding_window_blocks = (self.sliding_window //
  429. self.block_size)
  430. block_table = block_table[-sliding_window_blocks:]
  431. block_tables.append(block_table)
  432. # Aphrodite uses CUDA graph only for decoding requests.
  433. # See `capture_model` API for more details.
  434. # For decoding requests, batch_size == input_tokens.
  435. batch_size = len(input_tokens)
  436. max_context_len = max(context_lens)
  437. use_captured_graph = (
  438. not self.model_config.enforce_eager
  439. and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
  440. and max_context_len <= self.max_context_len_to_capture)
  441. if use_captured_graph:
  442. graph_batch_size = _get_graph_batch_size(batch_size)
  443. assert graph_batch_size >= batch_size
  444. for _ in range(graph_batch_size - batch_size):
  445. input_tokens.append(0)
  446. input_positions.append(0)
  447. slot_mapping.append(_PAD_SLOT_ID)
  448. context_lens.append(1)
  449. block_tables.append([])
  450. lora_index_mapping.append(0)
  451. batch_size = graph_batch_size
  452. context_lens = torch.tensor(context_lens,
  453. dtype=torch.int,
  454. device=self.device)
  455. if use_captured_graph:
  456. # When using cuda-graph all these tensors should be
  457. # padded.
  458. assert context_lens.shape[0] == len(input_tokens)
  459. assert context_lens.shape[0] == len(input_positions)
  460. assert context_lens.shape[0] == len(slot_mapping)
  461. # The shape of graph_block_tables is
  462. # [max batch size, max context len // block size].
  463. input_block_tables = self.graph_block_tables[:batch_size]
  464. for i, block_table in enumerate(block_tables):
  465. if block_table:
  466. input_block_tables[i, :len(block_table)] = block_table
  467. block_tables = torch.tensor(input_block_tables, device=self.device)
  468. else:
  469. max_block_table_len = max(
  470. len(block_table) for block_table in block_tables)
  471. block_tables = make_tensor_with_pad(
  472. block_tables,
  473. max_len=max_block_table_len,
  474. pad=0,
  475. dtype=torch.int,
  476. device=self.device,
  477. )
  478. attn_metadata = self.attn_backend.make_metadata(
  479. is_prompt=False,
  480. prompt_lens=None,
  481. prompt_lens_tensor=None,
  482. max_subquery_len=None,
  483. max_context_len=max_context_len,
  484. max_prompt_len=None,
  485. subquery_start_loc=None,
  486. seq_start_loc=None,
  487. context_lens=context_lens,
  488. block_tables=block_tables,
  489. use_cuda_graph=use_captured_graph,
  490. )
  491. return PrepareDecodeMetadata(
  492. input_tokens=input_tokens,
  493. input_positions=input_positions,
  494. attn_metadata=attn_metadata,
  495. lora_index_mapping=lora_index_mapping,
  496. lora_prompt_mapping=lora_prompt_mapping,
  497. lora_requests=lora_requests,
  498. slot_mapping=slot_mapping,
  499. )
  500. def _prepare_sample(
  501. self,
  502. seq_group_metadata_list: List[SequenceGroupMetadata],
  503. prompt_lens: List[int],
  504. subquery_lens: Optional[List[int]],
  505. ) -> SamplingMetadata:
  506. seq_groups: List[Tuple[List[int], SamplingParams]] = []
  507. selected_token_indices: List[int] = []
  508. generators: List[torch.Generator] = []
  509. selected_token_start_idx = 0
  510. categorized_sample_indices = {t: [] for t in SamplingType}
  511. categorized_sample_indices_start_idx = 0
  512. categorized_sampled_token_indices_start_idx = 0
  513. for i, seq_group_metadata in enumerate(seq_group_metadata_list):
  514. seq_ids = list(seq_group_metadata.seq_data.keys())
  515. sampling_params = seq_group_metadata.sampling_params
  516. seq_groups.append((seq_ids, sampling_params))
  517. if seq_group_metadata.is_prompt:
  518. assert len(seq_ids) == 1
  519. assert subquery_lens is not None
  520. subquery_len = subquery_lens[i]
  521. if sampling_params.prompt_logprobs is not None:
  522. # NOTE: prompt token positions do not need sample, skip
  523. categorized_sample_indices_start_idx += subquery_len - 1
  524. categorized_sample_indices[
  525. sampling_params.sampling_type].append([
  526. categorized_sample_indices_start_idx,
  527. categorized_sampled_token_indices_start_idx
  528. ])
  529. categorized_sample_indices_start_idx += 1
  530. categorized_sampled_token_indices_start_idx += 1
  531. if sampling_params.prompt_logprobs is not None:
  532. selected_token_indices.extend(
  533. range(selected_token_start_idx,
  534. selected_token_start_idx + subquery_len - 1))
  535. selected_token_indices.append(selected_token_start_idx +
  536. subquery_len - 1)
  537. selected_token_start_idx += subquery_len
  538. if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
  539. assert sampling_params.seed is not None
  540. seq_group_metadata.state.generator = torch.Generator(
  541. device=self.device).manual_seed(sampling_params.seed)
  542. else:
  543. num_seqs = len(seq_ids)
  544. selected_token_indices.extend(
  545. range(selected_token_start_idx,
  546. selected_token_start_idx + num_seqs))
  547. selected_token_start_idx += num_seqs
  548. categorized_sample_indices[
  549. sampling_params.sampling_type].extend(
  550. zip(
  551. range(
  552. categorized_sample_indices_start_idx,
  553. categorized_sample_indices_start_idx +
  554. num_seqs),
  555. range(
  556. categorized_sampled_token_indices_start_idx,
  557. categorized_sampled_token_indices_start_idx +
  558. num_seqs)))
  559. categorized_sample_indices_start_idx += num_seqs
  560. categorized_sampled_token_indices_start_idx += num_seqs
  561. if seq_group_metadata.state.generator is not None:
  562. generators.append(seq_group_metadata.state.generator)
  563. selected_token_indices = async_tensor_h2d(selected_token_indices,
  564. dtype=torch.long,
  565. target_device=self.device,
  566. pin_memory=self.pin_memory)
  567. categorized_sample_indices = {
  568. t: maybe_expand_dim(
  569. async_tensor_h2d(seq_ids,
  570. dtype=torch.int,
  571. target_device=self.device,
  572. pin_memory=self.pin_memory), 2, 2)
  573. for t, seq_ids in categorized_sample_indices.items()
  574. }
  575. seq_data: Dict[int, SequenceData] = {}
  576. for seq_group_metadata in seq_group_metadata_list:
  577. seq_data.update(seq_group_metadata.seq_data)
  578. seq_persistence_data: Dict[int, dict] = {}
  579. for grp in seq_group_metadata_list:
  580. seq_persistence_data.update(grp.persistent_data)
  581. sampling_metadata = SamplingMetadata(
  582. seq_groups=seq_groups,
  583. seq_data=seq_data,
  584. prompt_lens=prompt_lens,
  585. selected_token_indices=selected_token_indices,
  586. categorized_sample_indices=categorized_sample_indices,
  587. generators=generators,
  588. persistent_metadata=PersistentMetadata(seq_persistence_data),
  589. )
  590. return sampling_metadata
  591. def prepare_input_tensors(
  592. self,
  593. seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
  594. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
  595. Set[int], LoRAMapping, torch.Tensor]:
  596. if self.is_driver_worker:
  597. prefill_reqs = []
  598. decode_reqs = []
  599. for seq_group_meta in seq_group_metadata_list:
  600. if seq_group_meta.is_prompt:
  601. prefill_reqs.append(seq_group_meta)
  602. else:
  603. decode_reqs.append(seq_group_meta)
  604. # Prepare input tensors.
  605. (
  606. input_tokens,
  607. input_positions,
  608. prefill_attn_metadata,
  609. prompt_lens,
  610. subquery_lens,
  611. lora_index_mapping,
  612. lora_prompt_mapping,
  613. lora_requests,
  614. multi_modal_input,
  615. slot_mapping,
  616. ) = self._prepare_prompt(prefill_reqs)
  617. (
  618. decode_input_tokens,
  619. decode_input_positions,
  620. decode_attn_metadata,
  621. decode_lora_index_mapping,
  622. decode_lora_prompt_mapping,
  623. decode_lora_requests,
  624. decode_slot_mapping,
  625. ) = self._prepare_decode(decode_reqs)
  626. sampling_metadata = self._prepare_sample(seq_group_metadata_list,
  627. prompt_lens,
  628. subquery_lens)
  629. if not self.scheduler_config.chunked_prefill_enabled:
  630. assert (len(prefill_reqs) and len(decode_reqs)) == 0
  631. num_prefills = len(prompt_lens)
  632. num_prefill_tokens = len(input_tokens)
  633. num_decode_tokens = len(decode_input_tokens)
  634. # Coalesce tensors. Note that attn_metadata is currently not
  635. # coalesced for simplicity.
  636. input_tokens.extend(decode_input_tokens)
  637. input_positions.extend(decode_input_positions)
  638. slot_mapping.extend(decode_slot_mapping)
  639. lora_index_mapping.extend(decode_lora_index_mapping)
  640. lora_prompt_mapping.extend(decode_lora_prompt_mapping)
  641. lora_requests.update(decode_lora_requests)
  642. input_tokens = torch.tensor(input_tokens,
  643. dtype=torch.long,
  644. device=self.device)
  645. input_positions = torch.tensor(input_positions,
  646. dtype=torch.long,
  647. device=self.device)
  648. slot_mapping = torch.tensor(slot_mapping,
  649. dtype=torch.long,
  650. device=self.device)
  651. if self.lora_config:
  652. lora_mapping = LoRAMapping(
  653. lora_index_mapping,
  654. lora_prompt_mapping,
  655. )
  656. else:
  657. lora_mapping = None
  658. # Broadcast the metadata.
  659. # If batch contains both prefill and decode, it sends 2 broadcasts.
  660. # If it only contains 1 type, it triggers a single broadcast.
  661. if (prefill_attn_metadata is not None
  662. and decode_attn_metadata is not None):
  663. batch_type = BatchType.MIXED
  664. elif prefill_attn_metadata is not None:
  665. batch_type = BatchType.PREFILL
  666. else:
  667. batch_type = BatchType.DECODE
  668. metadata_dict = {
  669. "input_tokens": input_tokens,
  670. "input_positions": input_positions,
  671. "selected_token_indices":
  672. sampling_metadata.selected_token_indices,
  673. "lora_requests": lora_requests,
  674. "lora_mapping": lora_mapping,
  675. "multi_modal_input": multi_modal_input,
  676. "num_prefill_tokens": num_prefill_tokens,
  677. "num_decode_tokens": num_decode_tokens,
  678. "slot_mapping": slot_mapping,
  679. "num_prefills": num_prefills,
  680. "batch_type": batch_type,
  681. }
  682. if prefill_attn_metadata is not None:
  683. metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
  684. else:
  685. metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
  686. broadcast_tensor_dict(metadata_dict, src=0)
  687. # Broadcast decode attn metadata for mixed batch type.
  688. # The additional broadcast costs 300us overhead on 4 A10 GPUs.
  689. # We can potentially reduce the overhead by coelescing tensors.
  690. if batch_type == BatchType.MIXED:
  691. assert decode_attn_metadata is not None
  692. metadata_dict = decode_attn_metadata.asdict_zerocopy()
  693. broadcast_tensor_dict(metadata_dict, src=0)
  694. else:
  695. metadata_dict = broadcast_tensor_dict(src=0)
  696. input_tokens = metadata_dict.pop("input_tokens")
  697. input_positions = metadata_dict.pop("input_positions")
  698. slot_mapping = metadata_dict.pop("slot_mapping")
  699. num_prefills = metadata_dict.pop("num_prefills")
  700. selected_token_indices = metadata_dict.pop(
  701. "selected_token_indices")
  702. lora_mapping = metadata_dict.pop("lora_mapping")
  703. lora_requests = metadata_dict.pop("lora_requests")
  704. multi_modal_input = metadata_dict.pop("multi_modal_input")
  705. num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
  706. num_decode_tokens = metadata_dict.pop("num_decode_tokens")
  707. batch_type = metadata_dict.pop("batch_type")
  708. # Create an attention metadata.
  709. prefill_attn_metadata = None
  710. decode_attn_metadata = None
  711. if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
  712. prefill_attn_metadata = self.attn_backend.make_metadata(
  713. **metadata_dict)
  714. else:
  715. decode_attn_metadata = self.attn_backend.make_metadata(
  716. **metadata_dict)
  717. sampling_metadata = SamplingMetadata(
  718. seq_groups=None,
  719. seq_data=None,
  720. prompt_lens=None,
  721. selected_token_indices=selected_token_indices,
  722. categorized_sample_indices=None,
  723. generators=None,
  724. perform_sampling=False,
  725. )
  726. # if it is a mixed batch, decode attn_metadata is broadcasted
  727. # separately.
  728. if batch_type == BatchType.MIXED:
  729. metadata_dict = broadcast_tensor_dict(src=0)
  730. decode_attn_metadata = self.attn_backend.make_metadata(
  731. **metadata_dict)
  732. attn_metadata = AttentionMetadata(
  733. num_prefills=num_prefills,
  734. slot_mapping=slot_mapping,
  735. num_prefill_tokens=num_prefill_tokens,
  736. num_decode_tokens=num_decode_tokens,
  737. prefill_metadata=prefill_attn_metadata,
  738. decode_metadata=decode_attn_metadata,
  739. kv_cache_dtype=self.kv_cache_dtype,
  740. )
  741. return (input_tokens, input_positions, attn_metadata,
  742. sampling_metadata, lora_requests, lora_mapping,
  743. multi_modal_input)
  @torch.inference_mode()
  def execute_model(
      self,
      seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
      kv_caches: List[torch.Tensor],
  ) -> Optional[SamplerOutput]:
      """Run one model step for a batch and sample on the driver worker.

      Inputs are produced by ``self.prepare_input_tensors``. Active LoRAs are
      installed when ``self.lora_config`` is truthy. A captured CUDA graph
      runner (looked up by ``input_tokens.shape[0]``) is used instead of
      ``self.model`` when the batch has no prefill metadata and its decode
      metadata has ``use_cuda_graph`` set.

      Args:
          seq_group_metadata_list: Forwarded to ``prepare_input_tensors``.
          kv_caches: Per-layer KV cache tensors passed to the model.

      Returns:
          The result of ``self.model.sample`` when
          ``sampling_metadata.perform_sampling`` is true, otherwise ``None``.
      """
      (input_tokens, input_positions, attn_metadata, sampling_metadata,
       lora_requests, lora_mapping,
       multi_modal_input) = self.prepare_input_tensors(seq_group_metadata_list)

      if self.lora_config:
          self.set_active_loras(lora_requests, lora_mapping)

      # CUDA graphs are only captured for decode-only batches.
      is_decode_only = attn_metadata.prefill_metadata is None
      decode_meta = attn_metadata.decode_metadata
      if is_decode_only and decode_meta.use_cuda_graph:
          model_executable = self.graph_runners[input_tokens.shape[0]]
      else:
          model_executable = self.model

      model_kwargs = dict(
          input_ids=input_tokens,
          positions=input_positions,
          kv_caches=kv_caches,
          attn_metadata=attn_metadata,
      )
      if self.vision_language_config:
          model_kwargs["image_input"] = multi_modal_input
      hidden_states = model_executable(**model_kwargs)

      # Logits are computed on every worker; only the driver worker samples.
      logits = self.model.compute_logits(hidden_states, sampling_metadata)
      if not sampling_metadata.perform_sampling:
          return None
      return self.model.sample(
          logits=logits,
          sampling_metadata=sampling_metadata,
      )
  @torch.inference_mode()
  def profile_run(self) -> None:
      """Run a worst-case dummy prefill batch so peak memory can be measured.

      Builds up to ``scheduler_config.max_num_seqs`` prompt-only sequence
      groups whose lengths add up to ``scheduler_config.max_num_batched_tokens``
      and runs them through ``execute_model`` with one ``None`` KV cache per
      layer, followed by ``torch.cuda.synchronize()``.

      When LoRA is enabled, ``lora_config.max_loras`` dummy adapters are
      registered with the LoRA manager and assigned round-robin to the
      sequences. When a vision-language config is present, the number of
      sequences is reduced so each one can hold ``image_feature_size`` image
      tokens, but never below one sequence (an empty batch would profile
      nothing).
      """
      # Enable top-k sampling to reflect the accurate memory usage.
      sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
      max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
      max_num_seqs = self.scheduler_config.max_num_seqs

      # max_loras bounds how many distinct LoRAs a batch can use, and
      # therefore the worst-case LoRA memory consumption. Register that many
      # dummy adapters (at the warmup rank, with a fake local path) and
      # spread them round-robin over the profiling sequences.
      dummy_lora_requests = []
      dummy_lora_requests_per_seq = []
      if self.lora_config:
          for idx in range(self.lora_config.max_loras):
              lora_id = idx + 1
              dummy_lora_request = LoRARequest(
                  lora_name=f"warmup_{lora_id}",
                  lora_int_id=lora_id,
                  lora_local_path="/not/a/real/path",
              )
              self.lora_manager.add_dummy_lora(dummy_lora_request,
                                               rank=LORA_WARMUP_RANK)
              dummy_lora_requests.append(dummy_lora_request)
          dummy_lora_requests_per_seq = [
              dummy_lora_requests[idx % len(dummy_lora_requests)]
              for idx in range(max_num_seqs)
          ]

      # Profile memory usage with max_num_seqs sequences and the total
      # number of tokens equal to max_num_batched_tokens.
      seqs: List[SequenceGroupMetadata] = []
      # Vision encoding may need additional GPU memory, which must be
      # accounted for when sizing the GPU blocks for the Aphrodite block
      # manager. To exercise the worst case, the number of sequences (batch
      # size) is chosen to maximize the number of images processed.
      if self.vision_language_config:
          image_feature_size = (
              self.vision_language_config.image_feature_size)
          max_num_seqs = min(max_num_seqs,
                             max_num_batched_tokens // image_feature_size)
          if max_num_seqs < 1:
              # The token budget cannot hold even one image's features.
              # Without this clamp the loop below would build an empty
              # batch and the profile would measure no activation memory.
              logger.warning(
                  "max_num_batched_tokens (%d) is smaller than the image "
                  "feature size (%d); profiling with a single sequence whose "
                  "prompt exceeds the token budget.", max_num_batched_tokens,
                  image_feature_size)
              max_num_seqs = 1
      for group_id in range(max_num_seqs):
          # Distribute the token budget as evenly as possible; the first
          # (budget % max_num_seqs) sequences get one extra token.
          seq_len = (max_num_batched_tokens // max_num_seqs +
                     (group_id < max_num_batched_tokens % max_num_seqs))
          seq_data, fake_multi_modal_input = _prepare_fake_inputs(
              seq_len, self.vision_language_config)
          seq = SequenceGroupMetadata(
              request_id=str(group_id),
              is_prompt=True,
              seq_data={group_id: seq_data},
              sampling_params=sampling_params,
              block_tables=None,
              persistent_data={},
              lora_request=dummy_lora_requests_per_seq[group_id]
              if dummy_lora_requests_per_seq else None,
              multi_modal_data=fake_multi_modal_input,
          )
          seqs.append(seq)

      # Run the model with the dummy inputs.
      num_layers = self.model_config.get_num_layers(self.parallel_config)
      kv_caches = [None] * num_layers
      self.execute_model(seqs, kv_caches)
      torch.cuda.synchronize()
      return
  def remove_all_loras(self) -> bool:
      """Delegate to the LoRA manager's ``remove_all_loras`` and return its result.

      Raises:
          RuntimeError: If ``self.lora_manager`` is falsy (LoRA not enabled).
      """
      lora_manager = self.lora_manager
      if lora_manager:
          return lora_manager.remove_all_loras()
      raise RuntimeError("LoRA is not enabled.")
  def set_active_loras(self, lora_requests: List[LoRARequest],
                       lora_mapping: LoRAMapping) -> None:
      """Forward the active LoRA requests and mapping to the LoRA manager.

      Raises:
          RuntimeError: If ``self.lora_manager`` is falsy (LoRA not enabled).
      """
      lora_manager = self.lora_manager
      if lora_manager:
          lora_manager.set_active_loras(lora_requests, lora_mapping)
          return
      raise RuntimeError("LoRA is not enabled.")
  def add_lora(self, lora_request: LoRARequest) -> bool:
      """Delegate ``lora_request`` to the LoRA manager's ``add_lora``.

      Returns:
          Whatever the manager's ``add_lora`` returns.

      Raises:
          RuntimeError: If ``self.lora_manager`` is falsy (LoRA not enabled).
      """
      lora_manager = self.lora_manager
      if lora_manager:
          return lora_manager.add_lora(lora_request)
      raise RuntimeError("LoRA is not enabled.")
  def remove_lora(self, lora_id: int) -> bool:
      """Delegate ``lora_id`` to the LoRA manager's ``remove_lora``.

      Returns:
          Whatever the manager's ``remove_lora`` returns.

      Raises:
          RuntimeError: If ``self.lora_manager`` is falsy (LoRA not enabled).
      """
      lora_manager = self.lora_manager
      if lora_manager:
          return lora_manager.remove_lora(lora_id)
      raise RuntimeError("LoRA is not enabled.")
  def list_loras(self) -> Set[int]:
      """Return the LoRA manager's ``list_loras()`` result.

      Raises:
          RuntimeError: If ``self.lora_manager`` is falsy (LoRA not enabled).
      """
      lora_manager = self.lora_manager
      if lora_manager:
          return lora_manager.list_loras()
      raise RuntimeError("LoRA is not enabled.")
  @torch.inference_mode()
  def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
      """Capture CUDA graphs of the decode step for a set of batch sizes.

      CUDA graphs give a negligible gain once a batch has more than about
      200 tokens, and their fixed-size tensors make large or variable batch
      sizes expensive in GPU memory. Aphrodite therefore only captures
      decode-only batches; prefill and mixed (chunked prefill + decode)
      batches are not captured. Because capture is decode-only, every
      sequence is assumed to contribute exactly one token.

      One ``CUDAGraphRunner`` is stored in ``self.graph_runners`` for each
      size in ``_BATCH_SIZES_TO_CAPTURE`` that does not exceed the padded
      ``scheduler_config.max_num_seqs``. Successive captures are given the
      previous graph's pool via ``self.graph_memory_pool``.

      Args:
          kv_caches: KV cache tensors passed to every graph capture.
      """
      # NOTE: This is a hack to ensure that the NCCL backend is never
      # deleted before the CUDA graphs.
      self.pynccl_backend = pynccl_utils.get_nccl_backend()

      assert not self.model_config.enforce_eager
      logger.info("Capturing the model for CUDA graphs. This may lead to "
                  "unexpected consequences if the model is not static. To "
                  "run the model in eager mode, set 'enforce_eager=True' or "
                  "use '--enforce-eager' in the CLI.")
      logger.warning("CUDA graphs can take additional 1~3 GiB of memory "
                     "per GPU. If you are running out of memory, consider "
                     "decreasing `gpu_memory_utilization` or enforcing "
                     "eager mode.")
      start_time = time.perf_counter()

      # Dummy inputs sized for the largest capturable batch; every captured
      # graph uses a leading slice of these same tensors.
      largest_capture = max(_BATCH_SIZES_TO_CAPTURE)
      dummy_tokens = torch.zeros(largest_capture, dtype=torch.long).cuda()
      dummy_positions = torch.zeros(largest_capture, dtype=torch.long).cuda()
      dummy_slots = torch.empty(largest_capture, dtype=torch.long).cuda()
      dummy_slots.fill_(_PAD_SLOT_ID)
      dummy_context_lens = torch.ones(largest_capture,
                                      dtype=torch.int32).cuda()
      dummy_block_tables = torch.from_numpy(self.graph_block_tables).cuda()

      # Only capture sizes up to the padded maximum number of sequences.
      padded_max_seqs = _get_graph_batch_size(
          self.scheduler_config.max_num_seqs)
      capture_sizes = [
          size for size in _BATCH_SIZES_TO_CAPTURE if size <= padded_max_seqs
      ]

      # NOTE: There are 3 backends for all-reduce: custom all-reduce
      # kernel, PyNCCL, and PyTorch NCCL. When using CUDA graph, we use
      # either custom all-reduce kernel or PyNCCL. When not using CUDA
      # graph, we use either custom all-reduce kernel or PyTorch NCCL.
      # We always prioritize using custom all-reduce kernel but fall back
      # to PyTorch or PyNCCL if it is disabled or not supported.
      progress = get_loading_progress_bar()
      task = progress.add_task("[cyan]Capturing graph...",
                               total=len(capture_sizes))
      with progress, custom_all_reduce.capture():
          for batch_size in reversed(capture_sizes):
              # Decode-only attention metadata for this batch size.
              decode_metadata = self.attn_backend.make_metadata(
                  is_prompt=False,
                  prompt_lens=None,
                  prompt_lens_tensor=None,
                  max_subquery_len=None,
                  max_context_len=self.max_context_len_to_capture,
                  max_prompt_len=None,
                  subquery_start_loc=None,
                  seq_start_loc=None,
                  context_lens=dummy_context_lens[:batch_size],
                  block_tables=dummy_block_tables[:batch_size],
                  use_cuda_graph=True,
              )
              attn_metadata = AttentionMetadata(
                  num_prefills=0,
                  num_prefill_tokens=0,
                  num_decode_tokens=batch_size,
                  slot_mapping=dummy_slots[:batch_size],
                  prefill_metadata=None,
                  decode_metadata=decode_metadata,
                  kv_cache_dtype=self.kv_cache_dtype,
              )

              if self.lora_config:
                  # Every token maps to LoRA index 0 during capture.
                  self.set_active_loras(
                      set(),
                      LoRAMapping([0] * batch_size, [0] * batch_size))

              runner = CUDAGraphRunner(self.model)
              runner.capture(
                  dummy_tokens[:batch_size],
                  dummy_positions[:batch_size],
                  kv_caches,
                  attn_metadata,
                  memory_pool=self.graph_memory_pool,
              )
              # The next capture is handed this graph's memory pool.
              self.graph_memory_pool = runner.graph.pool()
              self.graph_runners[batch_size] = runner
              progress.update(task, advance=1)

      elapsed_time = time.perf_counter() - start_time
      # This usually takes < 10 seconds.
      logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
  def __del__(self) -> None:
      """Drop captured CUDA graphs before the PyNCCL backend reference.

      The graphs must be deleted before the pynccl communicator, because
      otherwise deadlocks can happen. ``__del__`` also runs on instances
      whose ``__init__`` raised before ``graph_runners`` was assigned, so the
      attribute is looked up defensively instead of letting the finalizer
      itself raise ``AttributeError``.
      """
      # FIXME: This is a bit hacky. Find a more robust solution.
      # TODO: when we get enough user feedback that pynccl is
      # more stable than cupy, we can remove this
      graph_runners = getattr(self, "graph_runners", None)
      if graph_runners is not None:
          graph_runners.clear()
      self.pynccl_backend = None
  @property
  def vocab_size(self) -> int:
      """Vocabulary size as reported by ``model_config.get_vocab_size()``."""
      model_config = self.model_config
      return model_config.get_vocab_size()
  class CUDAGraphRunner:
      """Replays a CUDA graph captured from one decode forward pass of a model.

      ``capture`` records the forward pass into a ``torch.cuda.CUDAGraph``
      and keeps references to the tensors the graph reads (``input_buffers``)
      and writes (``output_buffers``). ``forward`` copies fresh inputs into
      those tensors, replays the graph and returns the static output tensor.
      """

      def __init__(self, model: nn.Module):
          self.model = model
          # Static tensors the captured graph reads from / writes to; filled
          # in by capture().
          self.input_buffers: Dict[str, torch.Tensor] = {}
          self.output_buffers: Dict[str, torch.Tensor] = {}
          # Set to a torch.cuda.CUDAGraph by capture().
          self.graph = None

      def capture(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
          memory_pool,
          **kwargs,
      ) -> None:
          """Capture one forward pass of ``self.model`` into a CUDA graph.

          The passed tensors become the graph's static input buffers, so they
          must stay alive (and be overwritten in place) for later replays.
          ``memory_pool`` is forwarded to ``torch.cuda.graph``.
          """
          assert self.graph is None

          def run_model():
              return self.model(
                  input_ids,
                  positions,
                  kv_caches,
                  attn_metadata,
                  **kwargs,
              )

          # Run the model once without capturing so that the captured graph
          # does not include kernel launches for initial benchmarking
          # (e.g., Triton autotune).
          with _maybe_pynccl():
              run_model()
          torch.cuda.synchronize()

          # NOTE: Python 3.8 does not support multi-line with statements.
          # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
          self.graph = torch.cuda.CUDAGraph()
          with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
              with _maybe_pynccl():
                  hidden_states = run_model()
          torch.cuda.synchronize()

          decode_meta = attn_metadata.decode_metadata
          self.input_buffers = {
              "input_ids": input_ids,
              "positions": positions,
              "kv_caches": kv_caches,
              "slot_mapping": attn_metadata.slot_mapping,
              "context_lens": decode_meta.context_lens,
              "block_tables": decode_meta.block_tables,
          }
          self.output_buffers = {"hidden_states": hidden_states}

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
          **kwargs,
      ) -> torch.Tensor:
          """Copy new inputs into the static buffers and replay the graph.

          Extra ``kwargs`` are accepted but not used. Returns the graph's
          static ``hidden_states`` output tensor, which the next replay
          overwrites.
          """
          # KV caches are fixed tensors, so we don't need to copy them.
          del kv_caches
          decode_meta = attn_metadata.decode_metadata
          new_values = (
              ("input_ids", input_ids),
              ("positions", positions),
              ("slot_mapping", attn_metadata.slot_mapping),
              ("context_lens", decode_meta.context_lens),
              ("block_tables", decode_meta.block_tables),
          )
          for buffer_name, value in new_values:
              self.input_buffers[buffer_name].copy_(value, non_blocking=True)
          self.graph.replay()
          return self.output_buffers["hidden_states"]

      def __call__(self, *args, **kwargs):
          return self.forward(*args, **kwargs)
  @contextlib.contextmanager
  def _maybe_pynccl():
      """Route all-reduce through PyNCCL when it is the backend to use.

      Enters ``with_pynccl_for_all_reduce()`` only when PyNCCL is initialized
      and the custom all-reduce kernel is not; otherwise yields without any
      extra context.
      """
      use_pynccl = (pynccl_utils.is_initialized()
                    and not custom_all_reduce.is_initialized())
      context = (with_pynccl_for_all_reduce()
                 if use_pynccl else contextlib.nullcontext())
      with context:
          yield
  def _get_graph_batch_size(batch_size: int) -> int:
      """Return the padded batch size for an actual batch size.

      Sizes up to 2 are returned unchanged, 3 and 4 become 4, and anything
      larger is rounded up to the next multiple of ``_BATCH_SIZE_ALIGNMENT``.
      The resulting sequence is 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
      2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
      """
      if batch_size > 4:
          # Ceiling division via negated floor division, then scale back up.
          return -(-batch_size // _BATCH_SIZE_ALIGNMENT) * _BATCH_SIZE_ALIGNMENT
      if batch_size > 2:
          return 4
      return batch_size
  def _prepare_fake_inputs(
          seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
      """Prepare fake inputs for profile run.

      Without a vision-language config, the prompt is ``seq_len`` zero tokens
      and no image is produced. With one, the prompt is
      ``image_feature_size`` copies of ``image_token_id`` followed by zeros
      up to ``seq_len``. If ``seq_len`` is smaller than the feature size, no
      zeros are added and the prompt is longer than ``seq_len``. The image is
      an all-zero float16 tensor of ``image_input_shape`` wrapped in an IMAGE
      ``MultiModalData``.

      Returns:
          ``(SequenceData(prompt_tokens), fake_image_input_or_None)``.
      """
      if not vision_language_config:
          return SequenceData([0] * seq_len), None

      num_image_tokens = vision_language_config.image_feature_size
      prompt_tokens = [vision_language_config.image_token_id] * num_image_tokens
      prompt_tokens += [0] * (seq_len - num_image_tokens)
      fake_image_input = MultiModalData(
          type=MultiModalData.Type.IMAGE,
          data=torch.zeros(vision_language_config.image_input_shape,
                           dtype=torch.float16))
      return SequenceData(prompt_tokens), fake_image_input
  1085. else:
  1086. prompt_tokens = [0] * seq_len
  1087. fake_image_input = None
  1088. return SequenceData(prompt_tokens), fake_image_input