model_runner.py 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412
  1. import dataclasses
  2. import gc
  3. import time
  4. import warnings
  5. from collections import defaultdict
  6. from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
  7. TypeVar, Union)
  8. import numpy as np
  9. import torch
  10. import torch.distributed
  11. import torch.nn as nn
  12. from loguru import logger
  13. try:
  14. from flashinfer import BatchDecodeWithPagedKVCacheWrapper
  15. from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
  16. from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
  17. FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
  18. except ImportError:
  19. BatchDecodeWithPagedKVCacheWrapper = None
  20. CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
  21. BatchPrefillWithPagedKVCacheWrapper = None
  22. FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
  23. from aphrodite.attention import AttentionMetadata, get_attn_backend
  24. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  25. LoRAConfig, ModelConfig, ParallelConfig,
  26. SchedulerConfig, VisionLanguageConfig)
  27. from aphrodite.common.sampling_params import SamplingParams
  28. from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
  29. SequenceGroupMetadata)
  30. from aphrodite.common.utils import (CudaMemoryProfiler,
  31. get_kv_cache_torch_dtype, is_hip,
  32. is_pin_memory_available,
  33. make_tensor_with_pad)
  34. from aphrodite.distributed import get_pp_group
  35. from aphrodite.distributed.parallel_state import (
  36. get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
  37. graph_capture)
  38. from aphrodite.inputs import INPUT_REGISTRY
  39. from aphrodite.lora.layers import LoRAMapping
  40. from aphrodite.lora.request import LoRARequest
  41. from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
  42. from aphrodite.modeling import SamplingMetadata
  43. from aphrodite.modeling.model_loader import get_model
  44. from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
  45. from aphrodite.modeling.models.interfaces import supports_lora
  46. from aphrodite.multimodal import MULTIMODAL_REGISTRY
  47. from aphrodite.task_handler.model_runner_base import (
  48. ModelRunnerBase, ModelRunnerInputBase,
  49. _add_attn_metadata_broadcastable_dict,
  50. _add_sampling_metadata_broadcastable_dict,
  51. _init_attn_metadata_from_tensor_dict,
  52. _init_sampling_metadata_from_tensor_dict)
  53. if TYPE_CHECKING:
  54. from aphrodite.attention.backends.abstract import AttentionBackend
  55. _PAD_SLOT_ID = -1
  56. LORA_WARMUP_RANK = 8
  57. _BATCH_SIZE_ALIGNMENT = 8
  58. # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
  59. # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
  60. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
  61. _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
  62. ]
  63. _NUM_WARMUP_ITERS = 2
  64. TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
  65. @dataclasses.dataclass(frozen=True)
  66. class ModelInputForGPU(ModelRunnerInputBase):
  67. """
  68. This base class contains metadata needed for the base model forward pass
  69. but not metadata for possible additional steps, e.g., sampling. Model
  70. runners that run additional steps should subclass this method to add
  71. additional fields.
  72. """
  73. input_tokens: Optional[torch.Tensor] = None
  74. input_positions: Optional[torch.Tensor] = None
  75. seq_lens: Optional[List[int]] = None
  76. query_lens: Optional[List[int]] = None
  77. lora_mapping: Optional["LoRAMapping"] = None
  78. lora_requests: Optional[Set[LoRARequest]] = None
  79. attn_metadata: Optional["AttentionMetadata"] = None
  80. multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
  81. virtual_engine: int = 0
  82. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  83. tensor_dict = {
  84. "input_tokens": self.input_tokens,
  85. "input_positions": self.input_positions,
  86. "lora_requests": self.lora_requests,
  87. "lora_mapping": self.lora_mapping,
  88. "multi_modal_kwargs": self.multi_modal_kwargs,
  89. "virtual_engine": self.virtual_engine,
  90. }
  91. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  92. return tensor_dict
  93. @classmethod
  94. def from_broadcasted_tensor_dict(
  95. cls: Type[TModelInputForGPU],
  96. tensor_dict: Dict[str, Any],
  97. attn_backend: Optional["AttentionBackend"] = None,
  98. ) -> TModelInputForGPU:
  99. if attn_backend is not None:
  100. tensor_dict = _init_attn_metadata_from_tensor_dict(
  101. attn_backend, tensor_dict)
  102. return cls(**tensor_dict)
  103. @dataclasses.dataclass(frozen=True)
  104. class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
  105. """
  106. Used by the ModelRunner.
  107. """
  108. sampling_metadata: Optional["SamplingMetadata"] = None
  109. # Used for speculative decoding. We do not broadcast it because it is only
  110. # used by the driver worker.
  111. is_prompt: Optional[bool] = None
  112. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  113. tensor_dict = {
  114. "input_tokens": self.input_tokens,
  115. "input_positions": self.input_positions,
  116. "lora_requests": self.lora_requests,
  117. "lora_mapping": self.lora_mapping,
  118. "multi_modal_kwargs": self.multi_modal_kwargs,
  119. "virtual_engine": self.virtual_engine,
  120. }
  121. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  122. _add_sampling_metadata_broadcastable_dict(tensor_dict,
  123. self.sampling_metadata)
  124. return tensor_dict
  125. @classmethod
  126. def from_broadcasted_tensor_dict(
  127. cls,
  128. tensor_dict: Dict[str, Any],
  129. attn_backend: Optional["AttentionBackend"] = None,
  130. ) -> "ModelInputForGPUWithSamplingMetadata":
  131. tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
  132. if attn_backend is not None:
  133. tensor_dict = _init_attn_metadata_from_tensor_dict(
  134. attn_backend, tensor_dict)
  135. return cls(**tensor_dict)
  136. class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
  137. """
  138. Helper class for shared methods between GPU model runners.
  139. """
  140. _model_input_cls: Type[TModelInputForGPU]
  141. def __init__(
  142. self,
  143. model_config: ModelConfig,
  144. parallel_config: ParallelConfig,
  145. scheduler_config: SchedulerConfig,
  146. device_config: DeviceConfig,
  147. cache_config: CacheConfig,
  148. load_config: LoadConfig,
  149. lora_config: Optional[LoRAConfig],
  150. kv_cache_dtype: Optional[str] = "auto",
  151. is_driver_worker: bool = False,
  152. vision_language_config: Optional[VisionLanguageConfig] = None,
  153. return_hidden_states: bool = False,
  154. ):
  155. self.model_config = model_config
  156. self.parallel_config = parallel_config
  157. self.scheduler_config = scheduler_config
  158. self.device_config = device_config
  159. self.cache_config = cache_config
  160. self.lora_config = lora_config
  161. self.load_config = load_config
  162. self.is_driver_worker = is_driver_worker
  163. self.vision_language_config = vision_language_config
  164. self.return_hidden_states = return_hidden_states
  165. self.device = self.device_config.device
  166. self.pin_memory = is_pin_memory_available()
  167. self.kv_cache_dtype = kv_cache_dtype
  168. self.sliding_window = model_config.get_sliding_window()
  169. self.block_size = cache_config.block_size
  170. self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
  171. self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
  172. {} for _ in range(self.parallel_config.pipeline_parallel_size)
  173. ]
  174. self.graph_memory_pool: Optional[Tuple[
  175. int, int]] = None # Set during graph capture.
  176. # When using CUDA graph, the input block tables must be padded to
  177. # max_seq_len_to_capture. However, creating the block table in
  178. # Python can be expensive. To optimize this, we cache the block table
  179. # in numpy and only copy the actual input content at every iteration.
  180. # The shape of the cached block table will be
  181. # (max batch size to capture, max context len to capture / block size).
  182. self.graph_block_tables = np.zeros(
  183. (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
  184. dtype=np.int32)
  185. num_attn_heads = self.model_config.get_num_attention_heads(
  186. self.parallel_config)
  187. self.attn_backend = get_attn_backend(
  188. num_attn_heads,
  189. self.model_config.get_head_size(),
  190. self.model_config.get_num_kv_heads(self.parallel_config),
  191. self.model_config.get_sliding_window(),
  192. self.model_config.dtype,
  193. self.kv_cache_dtype,
  194. self.block_size,
  195. ) if num_attn_heads else None
  196. # Multi-modal data support
  197. self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
  198. .create_input_mapper(self.model_config)
  199. # Lazy initialization
  200. self.model: nn.Module # Set after load_model
  201. # Set after load_model.
  202. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
  203. self.flashinfer_decode_workspace_buffer = None
  204. self.flashinfer_decode_wrapper = None
  205. self.flashinfer_prefill_workspace_buffer = None
  206. self.flashinfer_prefill_wrapper = None
  207. def load_model(self) -> None:
  208. with CudaMemoryProfiler() as m:
  209. # measure the time it takes to load the model
  210. start_time = time.time()
  211. self.model = get_model(
  212. model_config=self.model_config,
  213. device_config=self.device_config,
  214. load_config=self.load_config,
  215. lora_config=self.lora_config,
  216. vision_language_config=self.vision_language_config,
  217. parallel_config=self.parallel_config,
  218. scheduler_config=self.scheduler_config,
  219. cache_config=self.cache_config,
  220. )
  221. end_time = time.time()
  222. self.model_memory_usage = m.consumed_memory
  223. tp = get_tensor_model_parallel_world_size()
  224. rank = get_tensor_model_parallel_rank()
  225. total_time = end_time - start_time
  226. if tp > 1:
  227. logger.info(
  228. f"Rank {rank}: Model weights loaded in {total_time:.2f} secs.")
  229. if rank == 0:
  230. logger.info(
  231. "Memory usage: "
  232. f"{self.model_memory_usage / float(2**30):.2f} GiB x {tp} ="
  233. f" {self.model_memory_usage * tp / float(2**30):.2f} GiB")
  234. else:
  235. logger.info(f"Model weights loaded in {total_time:.2f} seconds.")
  236. logger.info("Memory usage: "
  237. f"{self.model_memory_usage / float(2**30):.2f} GiB")
  238. if self.lora_config:
  239. assert supports_lora(self.model), "Model does not support LoRA"
  240. self.lora_manager = LRUCacheWorkerLoRAManager(
  241. self.scheduler_config.max_num_seqs,
  242. self.scheduler_config.max_num_batched_tokens,
  243. self.vocab_size,
  244. self.lora_config,
  245. self.device,
  246. self.model.embedding_modules,
  247. self.model.embedding_padding_modules,
  248. max_position_embeddings=self.model.config.
  249. max_position_embeddings,
  250. )
  251. self.model = self.lora_manager.create_lora_manager(self.model)
  252. if self.kv_cache_dtype == "fp8" and is_hip():
  253. # Currently only ROCm accepts kv-cache scaling factors
  254. # via quantization_param_path and this will be deprecated
  255. # in the future.
  256. if self.model_config.quantization_param_path is not None:
  257. if callable(getattr(self.model, "load_kv_cache_scales", None)):
  258. warnings.warn(
  259. "Loading kv cache scaling factor from JSON is "
  260. "deprecated and will be removed. Please include "
  261. "kv cache scaling factors in the model checkpoint.",
  262. FutureWarning,
  263. stacklevel=2)
  264. self.model.load_kv_cache_scales(
  265. self.model_config.quantization_param_path)
  266. logger.info(
  267. "Loaded KV cache scaling factors from ",
  268. f"{self.model_config.quantization_param_path}")
  269. else:
  270. raise RuntimeError(
  271. "Using FP8 KV cache and scaling factors provided but "
  272. f"model {self.model.__class__} does not support loading"
  273. " scaling factors.", )
  274. else:
  275. logger.warning(
  276. "Using FP8 KV cache but no scaling factors "
  277. "provided. Defaulting to scaling factors of 1.0. "
  278. "This may lead to less accurate results!")
  279. def save_sharded_state(
  280. self,
  281. path: str,
  282. pattern: Optional[str] = None,
  283. max_size: Optional[int] = None,
  284. ) -> None:
  285. from aphrodite.modeling.model_loader.loader import ShardedStateLoader
  286. ShardedStateLoader.save_model(
  287. self.model,
  288. path,
  289. pattern=pattern,
  290. max_size=max_size,
  291. )
  292. def save_tensorized_model(
  293. self,
  294. tensorizer_config: TensorizerConfig,
  295. ) -> None:
  296. from aphrodite.modeling.model_loader.loader import TensorizerLoader
  297. TensorizerLoader.save_model(
  298. self.model,
  299. tensorizer_config=tensorizer_config,
  300. )
  301. def get_max_block_per_batch(self) -> int:
  302. block_size = self.block_size
  303. return (self.max_seq_len_to_capture + block_size - 1) // block_size
  304. def _prepare_model_input_tensors(
  305. self,
  306. seq_group_metadata_list: List[SequenceGroupMetadata],
  307. ) -> TModelInputForGPU:
  308. """Helper method to prepare the model input based on a given sequence
  309. group. Prepares metadata needed for the base model forward pass but not
  310. metadata for possible additional steps, e.g., sampling.
  311. The API assumes seq_group_metadata_list is sorted by prefill -> decode.
  312. The result tensors and data structure also batches input in prefill
  313. -> decode order. For example,
  314. - input_tokens[:num_prefill_tokens] contains prefill tokens.
  315. - input_tokens[num_prefill_tokens:] contains decode tokens.
  316. If cuda graph is required, this API automatically pads inputs.
  317. """
  318. input_tokens: List[int] = []
  319. input_positions: List[int] = []
  320. slot_mapping: List[int] = []
  321. lora_index_mapping: List[int] = []
  322. lora_prompt_mapping: List[int] = []
  323. lora_requests: Set[LoRARequest] = set()
  324. seq_lens: List[int] = []
  325. prefill_seq_lens: List[int] = []
  326. decode_seq_lens: List[int] = []
  327. context_lens: List[int] = []
  328. query_lens: List[int] = []
  329. block_tables: List[List[int]] = []
  330. multi_modal_kwargs_list: Dict[str,
  331. List[torch.Tensor]] = defaultdict(list)
  332. decode_only = True
  333. num_prefills = 0
  334. num_prefill_tokens = 0
  335. num_decode_tokens = 0
  336. # The following fields are only for flashinfer
  337. # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
  338. # for the precise definition of the following fields.
  339. # An example:
  340. # request 1, page indices [0, 5, 8]
  341. # request 2, page indices [1, 6, 7]
  342. # request 3, page indices [3, 4]
  343. # paged_kv_indices is a concatenation of page indices of all requests:
  344. # [0, 5, 8, 1, 6, 7, 3, 4]
  345. # paged_kv_indptr is used to index into paged_kv_indices:
  346. # [0, 3, 6, 8]
  347. paged_kv_indices: List[int] = []
  348. # 0 at the beginning of paged_kv_indptr indicates the start of the
  349. # first request’s page indices in the paged_kv_indices list.
  350. paged_kv_indptr: List[int] = [0]
  351. # paged_kv_last_page_len is the length of the last page of each request
  352. paged_kv_last_page_len: List[int] = []
  353. if len(seq_group_metadata_list) == 0:
  354. return self._model_input_cls()
  355. if self.sliding_window is not None:
  356. sliding_window_blocks = (self.sliding_window + self.block_size -
  357. 1) // self.block_size
  358. block_aligned_sliding_window = \
  359. sliding_window_blocks * self.block_size
  360. for seq_group_metadata in seq_group_metadata_list:
  361. seq_ids = list(seq_group_metadata.seq_data.keys())
  362. is_prompt = seq_group_metadata.is_prompt
  363. for seq_id in seq_ids:
  364. computed_block_nums = seq_group_metadata.computed_block_nums
  365. if (self.scheduler_config is not None
  366. and self.scheduler_config.chunked_prefill_enabled
  367. and not (computed_block_nums is None
  368. or computed_block_nums == [])):
  369. raise RuntimeError(
  370. "chunked prefill cannot be used with prefix caching "
  371. "now.")
  372. seq_data = seq_group_metadata.seq_data[seq_id]
  373. if is_prompt:
  374. context_len = seq_data.get_num_computed_tokens()
  375. else:
  376. # get_num_computed_tokens is incorrect for spec decoding.
  377. # So, we should have a special logic here.
  378. # TODO: Fix it.
  379. context_len = seq_data.get_len() - 1
  380. seq_len = min(
  381. seq_data.get_len(),
  382. context_len + seq_group_metadata.token_chunk_size)
  383. if is_prompt:
  384. tokens = seq_data.get_token_ids()[context_len:seq_len]
  385. else:
  386. # Optimization. get_token_ids requires the entire copy of
  387. # tokens.
  388. tokens = [seq_data.get_last_token_id()]
  389. # Prefix cache was hit.
  390. # Prefix is not supported with sliding_window
  391. prefix_cache_hit = (computed_block_nums is not None
  392. and len(computed_block_nums) > 0
  393. and self.sliding_window is None
  394. and is_prompt)
  395. # These are seq_len/context_len capped to the sliding window.
  396. # They are passed to decode kernel.
  397. # We still need original seq_len/context_len to compute slot
  398. # mapping (and input position) below.
  399. curr_sliding_window_blocks = None
  400. sliding_seq_len = seq_len
  401. sliding_context_len = context_len
  402. # TODO: This is a hack to make sliding window work with
  403. # paged attn. We can remove it if we make paged attn kernel
  404. # to properly handle slinding window attn.
  405. if (self.sliding_window is not None and not is_prompt):
  406. curr_sliding_window_blocks = sliding_window_blocks
  407. if self.scheduler_config.use_v2_block_manager:
  408. # number of elements in last block
  409. suff_len = seq_len % self.block_size
  410. sliding_seq_len = min(
  411. seq_len, block_aligned_sliding_window + suff_len)
  412. if suff_len > 0:
  413. curr_sliding_window_blocks += 1
  414. else:
  415. sliding_seq_len = min(seq_len, self.sliding_window)
  416. sliding_context_len = sliding_seq_len - 1
  417. # TODO: Combine chunked prefill and prefix caching by
  418. # only allowing multiple of block_size chunk size.
  419. # NOTE: This only works for oooooooxxx style attention.
  420. if prefix_cache_hit:
  421. assert computed_block_nums is not None
  422. context_len = len(computed_block_nums) * self.block_size
  423. tokens = tokens[context_len:]
  424. # need to think what to set it to when we have both sliding
  425. # window and prefix caching...
  426. assert self.sliding_window is None, \
  427. "Prefix caching is not supported with sliding window"
  428. sliding_context_len = context_len
  429. if self.attn_backend.get_name() == "flash-attn":
  430. # NOTE: For flash-attn, the block table should
  431. # include the entries for the incoming prefill tokens.
  432. # TODO: This is a temporary fix. We should
  433. # provide a unified interface for different backends.
  434. block_table = seq_group_metadata.block_tables[seq_id]
  435. else:
  436. block_table = computed_block_nums
  437. elif (self.scheduler_config.chunked_prefill_enabled
  438. or not is_prompt):
  439. if seq_group_metadata.block_tables is not None:
  440. # chunked prefill or decode
  441. block_table = seq_group_metadata.block_tables[seq_id]
  442. if curr_sliding_window_blocks is not None:
  443. block_table = block_table[
  444. -curr_sliding_window_blocks:]
  445. else:
  446. # Only happens when memory profiling runs.
  447. block_table = []
  448. else:
  449. # Prefill without chunked prefill or memory profiling.
  450. block_table = []
  451. block_tables.append(block_table)
  452. seq_lens.append(sliding_seq_len)
  453. context_lens.append(sliding_context_len)
  454. query_len = sliding_seq_len - sliding_context_len
  455. query_lens.append(query_len)
  456. input_tokens.extend(tokens)
  457. input_positions.extend(list(range(context_len, seq_len)))
  458. lora_id = seq_group_metadata.lora_int_id
  459. if is_prompt:
  460. assert len(seq_ids) == 1
  461. num_prefills += 1
  462. num_prefill_tokens += len(tokens)
  463. decode_only = False
  464. prefill_seq_lens.append(seq_len)
  465. else:
  466. assert query_len == 1, (
  467. "seq_len: {}, context_len: {}, query_len: {}".format(
  468. seq_len, context_len, query_len))
  469. num_decode_tokens += query_len
  470. decode_seq_lens.append(sliding_seq_len)
  471. if lora_id > 0:
  472. lora_requests.add(seq_group_metadata.lora_request)
  473. lora_index_mapping += [lora_id] * query_len
  474. lora_prompt_mapping.extend(
  475. [lora_id] *
  476. (query_len if seq_group_metadata.sampling_params
  477. and seq_group_metadata.sampling_params.prompt_logprobs
  478. is not None else 1))
  479. mm_data = seq_group_metadata.multi_modal_data
  480. if mm_data:
  481. # Process multi-modal data
  482. mm_kwargs = self.multi_modal_input_mapper(mm_data)
  483. for k, v in mm_kwargs.items():
  484. multi_modal_kwargs_list[k].append(v)
  485. is_profile_run = _is_block_tables_empty(
  486. seq_group_metadata.block_tables)
  487. if is_profile_run:
  488. # During memory profiling, the block tables are not
  489. # initialized yet. In this case, we just use a dummy
  490. # slot mapping.
  491. # In embeddings, the block tables are {seq_id: None}.
  492. slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
  493. continue
  494. # Compute the slot mapping.
  495. block_table = seq_group_metadata.block_tables[seq_id]
  496. # Mask the [0, start_idx) tokens of the prompt with
  497. # _PAD_SLOT_ID, where start_idx is max(0, seq_len -
  498. # sliding_window). For example, if the prompt len is 10,
  499. # sliding window is 8, and block size is 4, the first two
  500. # tokens are masked and the slot mapping will be
  501. # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
  502. start_idx = 0
  503. if self.sliding_window is not None:
  504. if is_prompt:
  505. assert self.scheduler_config.use_v2_block_manager \
  506. or context_len == 0, (
  507. "Prefix caching is currently not supported with "
  508. "sliding window attention in V1 block manager")
  509. # It is an optimization. When it is decoding, it is always
  510. # 0. When prefill, we use it to not write slots to kv cache
  511. # to save memory.
  512. start_idx = max(0, query_len - self.sliding_window)
  513. for i in range(context_len, seq_len):
  514. if i < start_idx:
  515. slot_mapping.append(_PAD_SLOT_ID)
  516. continue
  517. block_number = block_table[i // self.block_size]
  518. block_offset = i % self.block_size
  519. slot = block_number * self.block_size + block_offset
  520. slot_mapping.append(slot)
  521. # Prepare input tensors for flashinfer
  522. if self.attn_backend.get_name() == "flashinfer":
  523. seq_len = seq_data.get_len()
  524. # Get the number of valid blocks based on sequence length.
  525. # If seq_len = 16, block_size = 16,
  526. # block_table_bound is 1 with 1 valid block.
  527. # If seq_len = 15, block_size = 16,
  528. # block_table_bound is 0 + 1 with 1 valid block.
  529. block_table_bound = seq_len // self.block_size + 1 \
  530. if seq_len % self.block_size != 0 \
  531. else seq_len // self.block_size
  532. paged_kv_indices.extend(block_table[:block_table_bound])
  533. paged_kv_indptr.append(paged_kv_indptr[-1] +
  534. block_table_bound)
  535. last_page_len = seq_len % self.block_size
  536. if last_page_len == 0:
  537. last_page_len = self.block_size
  538. paged_kv_last_page_len.append(last_page_len)
  539. batch_size = len(input_tokens)
  540. max_query_len = max(query_lens)
  541. max_prefill_seq_len = max(prefill_seq_lens, default=0)
  542. max_decode_seq_len = max(decode_seq_lens, default=0)
  543. # If cuda graph can be used, pad tensors accordingly.
  544. # See `capture_model` API for more details.
  545. # Aphrodite uses cuda graph only for decoding requests.
  546. use_captured_graph = (
  547. decode_only and not self.model_config.enforce_eager
  548. and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
  549. and max_decode_seq_len <= self.max_seq_len_to_capture)
  550. if use_captured_graph:
  551. graph_batch_size = _get_graph_batch_size(batch_size)
  552. assert graph_batch_size >= batch_size
  553. for _ in range(graph_batch_size - batch_size):
  554. input_tokens.append(0)
  555. input_positions.append(0)
  556. slot_mapping.append(_PAD_SLOT_ID)
  557. seq_lens.append(1)
  558. block_tables.append([])
  559. lora_index_mapping.append(0)
  560. if self.attn_backend.get_name() == "flashinfer":
  561. last_paged_kv_indptr = paged_kv_indptr[-1]
  562. paged_kv_indptr.append(last_paged_kv_indptr)
  563. paged_kv_last_page_len.append(0)
  564. batch_size = graph_batch_size
  565. num_decode_tokens = batch_size
  566. if use_captured_graph:
  567. # The shape of graph_block_tables is
  568. # [max batch size, max context len // block size].
  569. input_block_tables = self.graph_block_tables[:batch_size]
  570. for i, block_table in enumerate(block_tables):
  571. if block_table:
  572. input_block_tables[i, :len(block_table)] = block_table
  573. block_tables = torch.tensor(input_block_tables, device=self.device)
  574. else:
  575. max_block_table_len = max(
  576. len(block_table) for block_table in block_tables)
  577. block_tables = make_tensor_with_pad(
  578. block_tables,
  579. max_len=max_block_table_len,
  580. pad=0,
  581. dtype=torch.int,
  582. device=self.device,
  583. )
  584. assert max_query_len > 0, ("query_lens: {}".format(query_lens))
  585. context_lens_tensor = torch.tensor(context_lens,
  586. dtype=torch.int,
  587. device=self.device)
  588. seq_lens_tensor = torch.tensor(seq_lens,
  589. dtype=torch.int,
  590. device=self.device)
  591. query_lens_tensor = torch.tensor(query_lens,
  592. dtype=torch.long,
  593. device=self.device)
  594. query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
  595. dtype=torch.int32,
  596. device=self.device)
  597. seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
  598. dtype=torch.int32,
  599. device=self.device)
  600. torch.cumsum(seq_lens_tensor,
  601. dim=0,
  602. dtype=seq_start_loc.dtype,
  603. out=seq_start_loc[1:])
  604. torch.cumsum(query_lens_tensor,
  605. dim=0,
  606. dtype=query_start_loc.dtype,
  607. out=query_start_loc[1:])
  608. input_tokens_tensor = torch.tensor(input_tokens,
  609. dtype=torch.long,
  610. device=self.device)
  611. input_positions_tensor = torch.tensor(input_positions,
  612. dtype=torch.long,
  613. device=self.device)
  614. slot_mapping_tensor = torch.tensor(slot_mapping,
  615. dtype=torch.long,
  616. device=self.device)
  617. if self.attn_backend.get_name() == "flashinfer":
  618. if len(paged_kv_indptr) > 0:
  619. paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
  620. device='cpu',
  621. dtype=torch.int)
  622. paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
  623. device='cpu',
  624. dtype=torch.int)
  625. paged_kv_last_page_len_tensor = torch.tensor(
  626. paged_kv_last_page_len, device='cpu', dtype=torch.int)
  627. else:
  628. paged_kv_indices_tensor = None
  629. paged_kv_indptr_tensor = None
  630. paged_kv_last_page_len_tensor = None
  631. kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
  632. self.model_config.dtype)
  633. attn_metadata = self.attn_backend.make_metadata(
  634. num_prefills=num_prefills,
  635. slot_mapping=slot_mapping_tensor,
  636. num_prefill_tokens=num_prefill_tokens,
  637. num_decode_tokens=num_decode_tokens,
  638. max_prefill_seq_len=max_prefill_seq_len,
  639. block_tables=block_tables,
  640. paged_kv_indptr=paged_kv_indptr_tensor,
  641. paged_kv_indices=paged_kv_indices_tensor,
  642. paged_kv_last_page_len=paged_kv_last_page_len_tensor,
  643. num_qo_heads=self.model_config.get_num_attention_heads(
  644. self.parallel_config),
  645. num_kv_heads=self.model_config.get_num_kv_heads(
  646. self.parallel_config),
  647. head_dim=self.model_config.get_head_size(),
  648. page_size=self.block_size,
  649. seq_start_loc=seq_start_loc,
  650. query_start_loc=query_start_loc,
  651. device=self.device,
  652. data_type=kv_cache_dtype,
  653. use_cuda_graph=use_captured_graph)
  654. else:
  655. attn_metadata = self.attn_backend.make_metadata(
  656. num_prefills=num_prefills,
  657. slot_mapping=slot_mapping_tensor,
  658. num_prefill_tokens=num_prefill_tokens,
  659. num_decode_tokens=num_decode_tokens,
  660. seq_lens=seq_lens,
  661. seq_lens_tensor=seq_lens_tensor,
  662. max_query_len=max_query_len,
  663. max_prefill_seq_len=max_prefill_seq_len,
  664. max_decode_seq_len=max_decode_seq_len,
  665. query_start_loc=query_start_loc,
  666. seq_start_loc=seq_start_loc,
  667. context_lens_tensor=context_lens_tensor,
  668. block_tables=block_tables,
  669. use_cuda_graph=use_captured_graph,
  670. )
  671. if self.lora_config:
  672. lora_mapping = LoRAMapping(
  673. lora_index_mapping,
  674. lora_prompt_mapping,
  675. )
  676. else:
  677. lora_mapping = None
  678. multi_modal_kwargs = {
  679. k: torch.cat(v, dim=0).to(self.device)
  680. for k, v in multi_modal_kwargs_list.items()
  681. }
  682. return self._model_input_cls(
  683. input_tokens=input_tokens_tensor,
  684. input_positions=input_positions_tensor,
  685. attn_metadata=attn_metadata,
  686. seq_lens=seq_lens,
  687. query_lens=query_lens,
  688. lora_mapping=lora_mapping,
  689. lora_requests=lora_requests,
  690. multi_modal_kwargs=multi_modal_kwargs,
  691. )
  692. @torch.inference_mode()
  693. def profile_run(self) -> None:
  694. # Enable top-k sampling to reflect the accurate memory usage.
  695. sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
  696. max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
  697. max_num_seqs = self.scheduler_config.max_num_seqs
  698. # This represents the maximum number of different requests
  699. # that will have unique loras, an therefore the max amount of memory
  700. # consumption create dummy lora request copies from the lora request
  701. # passed in, which contains a lora from the lora warmup path.
  702. dummy_lora_requests: List[LoRARequest] = []
  703. dummy_lora_requests_per_seq: List[LoRARequest] = []
  704. if self.lora_config:
  705. assert self.lora_manager is not None
  706. with self.lora_manager.dummy_lora_cache():
  707. for idx in range(self.lora_config.max_loras):
  708. lora_id = idx + 1
  709. dummy_lora_request = LoRARequest(
  710. lora_name=f"warmup_{lora_id}",
  711. lora_int_id=lora_id,
  712. lora_local_path="/not/a/real/path",
  713. )
  714. self.lora_manager.add_dummy_lora(dummy_lora_request,
  715. rank=LORA_WARMUP_RANK)
  716. dummy_lora_requests.append(dummy_lora_request)
  717. dummy_lora_requests_per_seq = [
  718. dummy_lora_requests[idx % len(dummy_lora_requests)]
  719. for idx in range(max_num_seqs)
  720. ]
  721. # Profile memory usage with max_num_sequences sequences and the total
  722. # number of tokens equal to max_num_batched_tokens.
  723. seqs: List[SequenceGroupMetadata] = []
  724. # Additional GPU memory may be needed for vision encoding, which needs
  725. # to be accounted for when calculating the GPU blocks for
  726. # Aphrodite blocker manager.
  727. # To exercise the worst scenario for GPU memory consumption,
  728. # the number of seqs (batch_size) is chosen to maximize the number
  729. # of images processed.
  730. model_config = self.model_config
  731. vlm_config = self.vision_language_config
  732. if vlm_config:
  733. max_num_seqs = min(
  734. max_num_seqs,
  735. int(max_num_batched_tokens / vlm_config.image_feature_size))
  736. batch_size = 0
  737. for group_id in range(max_num_seqs):
  738. seq_len = (max_num_batched_tokens // max_num_seqs +
  739. (group_id < max_num_batched_tokens % max_num_seqs))
  740. batch_size += seq_len
  741. seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
  742. .dummy_data_for_profiling(model_config, seq_len)
  743. assert len(seq_data.prompt_token_ids) == seq_len
  744. seq = SequenceGroupMetadata(
  745. request_id=str(group_id),
  746. is_prompt=True,
  747. seq_data={group_id: seq_data},
  748. sampling_params=sampling_params,
  749. block_tables=None,
  750. lora_request=dummy_lora_requests_per_seq[group_id]
  751. if dummy_lora_requests_per_seq else None,
  752. multi_modal_data=dummy_multi_modal_data,
  753. )
  754. seqs.append(seq)
  755. # Run the model with the dummy inputs.
  756. num_layers = self.model_config.get_num_layers(self.parallel_config)
  757. kv_caches = [None] * num_layers
  758. model_input = self.prepare_model_input(seqs)
  759. intermediate_tensors = None
  760. if not get_pp_group().is_first_rank:
  761. intermediate_tensors = self.model.make_empty_intermediate_tensors(
  762. batch_size=batch_size,
  763. dtype=self.model_config.dtype,
  764. device=self.device)
  765. self.execute_model(model_input, kv_caches, intermediate_tensors)
  766. torch.cuda.synchronize()
  767. return
  768. def remove_all_loras(self):
  769. if not self.lora_manager:
  770. raise RuntimeError("LoRA is not enabled.")
  771. self.lora_manager.remove_all_loras()
  772. def set_active_loras(self, lora_requests: Set[LoRARequest],
  773. lora_mapping: LoRAMapping) -> None:
  774. if not self.lora_manager:
  775. raise RuntimeError("LoRA is not enabled.")
  776. self.lora_manager.set_active_loras(lora_requests, lora_mapping)
  777. def add_lora(self, lora_request: LoRARequest) -> bool:
  778. if not self.lora_manager:
  779. raise RuntimeError("LoRA is not enabled.")
  780. return self.lora_manager.add_lora(lora_request)
  781. def remove_lora(self, lora_id: int) -> bool:
  782. if not self.lora_manager:
  783. raise RuntimeError("LoRA is not enabled.")
  784. return self.lora_manager.remove_lora(lora_id)
  785. def pin_lora(self, lora_id: int) -> bool:
  786. if not self.lora_manager:
  787. raise RuntimeError("LoRA is not enabled.")
  788. return self.lora_manager.pin_lora(lora_id)
  789. def list_loras(self) -> Set[int]:
  790. if not self.lora_manager:
  791. raise RuntimeError("LoRA is not enabled.")
  792. return self.lora_manager.list_loras()
  793. @torch.inference_mode()
  794. def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
  795. """Cuda graph capture a model.
  796. Note that CUDA graph's performance gain is negligible if number
  797. of batched tokens are larger than 200. And since CUDA graph
  798. requires fixed sized tensors, supporting large/variable batch
  799. size requires high GPU memory overhead. Thus, Aphrodite only captures
  800. decoding requests. Mixed batch (chunked prefill + decoding) or
  801. prefill requests are not captured.
  802. Since it is used for decoding-only, it assumes there's only 1 token
  803. per sequence in the batch.
  804. """
  805. assert not self.model_config.enforce_eager
  806. logger.info("Capturing the model for CUDA graphs. This may lead to "
  807. "unexpected consequences if the model is not static. To "
  808. "run the model in eager mode, set 'enforce_eager=True' or "
  809. "use '--enforce-eager' in the CLI.")
  810. logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
  811. "If you are running out of memory, consider decreasing "
  812. "`gpu_memory_utilization` or enforcing eager mode. "
  813. "You can also reduce the `max_num_seqs` as needed "
  814. "to decrease memory usage.")
  815. start_time = time.perf_counter()
  816. # Prepare dummy inputs. These will be reused for all batch sizes.
  817. max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
  818. input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
  819. input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
  820. slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
  821. slot_mapping.fill_(_PAD_SLOT_ID)
  822. seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
  823. block_tables = torch.from_numpy(self.graph_block_tables).cuda()
  824. intermediate_inputs = None
  825. if not get_pp_group().is_first_rank:
  826. intermediate_inputs = self.model.make_empty_intermediate_tensors(
  827. batch_size=max_batch_size,
  828. dtype=self.model_config.dtype,
  829. device=self.device)
  830. # Prepare buffer for outputs. These will be reused for all batch sizes.
  831. # It will be filled after the first graph capture.
  832. hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [
  833. None
  834. ] * self.parallel_config.pipeline_parallel_size
  835. graph_batch_size = _get_graph_batch_size(
  836. self.scheduler_config.max_num_seqs)
  837. batch_size_capture_list = [
  838. bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
  839. ]
  840. if self.attn_backend.get_name() == "flashinfer":
  841. # For flashinfer, different batch sizes will share the
  842. # same workspace buffer.
  843. decode_workspace_buffer = \
  844. torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE,
  845. dtype=torch.uint8,
  846. device=self.device)
  847. indices_buffer = torch.empty(max_batch_size *
  848. self.cache_config.num_gpu_blocks,
  849. dtype=torch.int32,
  850. device=self.device)
  851. indptr_buffer = torch.empty(max_batch_size + 1,
  852. dtype=torch.int32,
  853. device=self.device)
  854. last_page_len_buffer = torch.empty(max_batch_size,
  855. dtype=torch.int32,
  856. device=self.device)
  857. with graph_capture() as graph_capture_context:
  858. # NOTE: Capturing the largest batch size first may help reduce the
  859. # memory usage of CUDA graph.
  860. for virtual_engine in range(
  861. self.parallel_config.pipeline_parallel_size):
  862. for batch_size in reversed(batch_size_capture_list):
  863. if self.attn_backend.get_name() == "flashinfer":
  864. indptr_buffer = indptr_buffer[:batch_size + 1]
  865. last_page_len_buffer = last_page_len_buffer[:
  866. batch_size]
  867. num_qo_heads = (
  868. self.model_config.get_num_attention_heads(
  869. self.parallel_config))
  870. num_kv_heads = self.model_config.get_num_kv_heads(
  871. self.parallel_config)
  872. if num_qo_heads // num_kv_heads >= 4:
  873. use_tensor_cores = True
  874. else:
  875. use_tensor_cores = False
  876. decode_wrapper = \
  877. CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
  878. decode_workspace_buffer, indptr_buffer,
  879. indices_buffer, last_page_len_buffer, "NHD",
  880. use_tensor_cores)
  881. kv_cache_dtype = get_kv_cache_torch_dtype(
  882. self.kv_cache_dtype, self.model_config.dtype)
  883. paged_kv_indptr_tensor_host = torch.arange(
  884. 0, batch_size + 1, dtype=torch.int32)
  885. paged_kv_indices_tensor_host = torch.arange(
  886. 0, batch_size, dtype=torch.int32)
  887. paged_kv_last_page_len_tensor_host = torch.full(
  888. (batch_size, ), self.block_size, dtype=torch.int32)
  889. query_start_loc_host = torch.arange(0,
  890. batch_size + 1,
  891. dtype=torch.int32)
  892. attn_metadata = self.attn_backend.make_metadata(
  893. num_prefills=0,
  894. slot_mapping=slot_mapping[:batch_size],
  895. num_prefill_tokens=0,
  896. num_decode_tokens=batch_size,
  897. max_prefill_seq_len=0,
  898. block_tables=block_tables,
  899. paged_kv_indptr=paged_kv_indptr_tensor_host,
  900. paged_kv_indices=paged_kv_indices_tensor_host,
  901. paged_kv_last_page_len=
  902. paged_kv_last_page_len_tensor_host,
  903. num_qo_heads=num_qo_heads,
  904. num_kv_heads=num_kv_heads,
  905. head_dim=self.model_config.get_head_size(),
  906. page_size=self.block_size,
  907. seq_start_loc=None,
  908. query_start_loc=query_start_loc_host,
  909. device=self.device,
  910. data_type=kv_cache_dtype,
  911. use_cuda_graph=True,
  912. decode_wrapper=decode_wrapper,
  913. prefill_wrapper=None)
  914. attn_metadata.begin_forward()
  915. else:
  916. attn_metadata = self.attn_backend.make_metadata(
  917. num_prefills=0,
  918. num_prefill_tokens=0,
  919. num_decode_tokens=batch_size,
  920. slot_mapping=slot_mapping[:batch_size],
  921. seq_lens=None,
  922. seq_lens_tensor=seq_lens[:batch_size],
  923. max_query_len=None,
  924. max_prefill_seq_len=0,
  925. max_decode_seq_len=self.max_seq_len_to_capture,
  926. query_start_loc=None,
  927. seq_start_loc=None,
  928. context_lens_tensor=None,
  929. block_tables=block_tables[:batch_size],
  930. use_cuda_graph=True,
  931. )
  932. if self.lora_config:
  933. lora_mapping = LoRAMapping(
  934. [0] * batch_size,
  935. [0] * batch_size,
  936. )
  937. self.set_active_loras(set(), lora_mapping)
  938. graph_runner = CUDAGraphRunner(
  939. self.model, self.attn_backend.get_name())
  940. if self.attn_backend.get_name() == "flashinfer":
  941. graph_runner.flashinfer_indptr_buffer = indptr_buffer
  942. graph_runner.flashinfer_indices_buffer = indices_buffer
  943. graph_runner.flashinfer_last_page_len_buffer = \
  944. last_page_len_buffer
  945. graph_runner.flashinfer_decode_workspace_buffer = \
  946. decode_workspace_buffer
  947. graph_runner.flashinfer_decode_wrapper = \
  948. decode_wrapper
  949. graph_runner.capture(
  950. input_tokens[:batch_size],
  951. input_positions[:batch_size],
  952. hidden_or_intermediate_states[
  953. virtual_engine] # type: ignore
  954. [:batch_size]
  955. if hidden_or_intermediate_states[virtual_engine]
  956. is not None else None,
  957. intermediate_inputs[:batch_size]
  958. if intermediate_inputs is not None else None,
  959. kv_caches[virtual_engine],
  960. attn_metadata,
  961. memory_pool=self.graph_memory_pool,
  962. stream=graph_capture_context.stream,
  963. )
  964. self.graph_memory_pool = graph_runner.graph.pool()
  965. self.graph_runners[virtual_engine][batch_size] = (
  966. graph_runner)
  967. end_time = time.perf_counter()
  968. elapsed_time = end_time - start_time
  969. # This usually takes < 10 seconds.
  970. logger.info(f"Graph capturing finished in {elapsed_time:2f} secs.")
  971. @property
  972. def vocab_size(self) -> int:
  973. return self.model_config.get_vocab_size()
  974. class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
  975. """
  976. GPU model runner with sampling step.
  977. """
  978. _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
  979. ModelInputForGPUWithSamplingMetadata)
  980. def make_model_input_from_broadcasted_tensor_dict(
  981. self,
  982. tensor_dict: Dict[str, Any],
  983. ) -> ModelInputForGPUWithSamplingMetadata:
  984. model_input = \
  985. ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
  986. tensor_dict,
  987. attn_backend=self.attn_backend,
  988. )
  989. return model_input
  990. def prepare_model_input(
  991. self,
  992. seq_group_metadata_list: List[SequenceGroupMetadata],
  993. virtual_engine: int = 0,
  994. ) -> ModelInputForGPUWithSamplingMetadata:
  995. """Prepare the model input based on a given sequence group, including
  996. metadata for the sampling step.
  997. The API assumes seq_group_metadata_list is sorted by prefill -> decode.
  998. The result tensors and data structure also batches input in prefill
  999. -> decode order. For example,
  1000. - input_tokens[:num_prefill_tokens] contains prefill tokens.
  1001. - input_tokens[num_prefill_tokens:] contains decode tokens.
  1002. If cuda graph is required, this API automatically pads inputs.
  1003. """
  1004. model_input = self._prepare_model_input_tensors(
  1005. seq_group_metadata_list)
  1006. sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
  1007. model_input.seq_lens,
  1008. model_input.query_lens,
  1009. self.device,
  1010. self.pin_memory)
  1011. is_prompt = (seq_group_metadata_list[0].is_prompt
  1012. if seq_group_metadata_list else None)
  1013. return dataclasses.replace(model_input,
  1014. sampling_metadata=sampling_metadata,
  1015. is_prompt=is_prompt,
  1016. virtual_engine=virtual_engine)
  1017. @torch.inference_mode()
  1018. def execute_model(
  1019. self,
  1020. model_input: ModelInputForGPUWithSamplingMetadata,
  1021. kv_caches: List[torch.Tensor],
  1022. intermediate_tensors: Optional[IntermediateTensors] = None,
  1023. num_steps: int = 1,
  1024. ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
  1025. if num_steps > 1:
  1026. raise ValueError("num_steps > 1 is not supported in ModelRunner")
  1027. if self.lora_config:
  1028. assert model_input.lora_requests is not None
  1029. assert model_input.lora_mapping is not None
  1030. self.set_active_loras(model_input.lora_requests,
  1031. model_input.lora_mapping)
  1032. if self.attn_backend.get_name() == "flashinfer":
  1033. assert model_input.attn_metadata is not None
  1034. assert model_input.input_tokens is not None
  1035. if self.flashinfer_decode_workspace_buffer is None:
  1036. self.flashinfer_decode_workspace_buffer = torch.empty(
  1037. FLASHINFER_WORKSPACE_BUFFER_SIZE,
  1038. dtype=torch.uint8,
  1039. device=self.device)
  1040. self.flashinfer_decode_wrapper = \
  1041. BatchDecodeWithPagedKVCacheWrapper(
  1042. self.flashinfer_decode_workspace_buffer, "NHD")
  1043. self.flashinfer_prefill_workspace_buffer = torch.empty(
  1044. FLASHINFER_WORKSPACE_BUFFER_SIZE,
  1045. dtype=torch.uint8,
  1046. device=self.device)
  1047. self.flashinfer_prefill_wrapper = \
  1048. BatchPrefillWithPagedKVCacheWrapper(
  1049. self.flashinfer_prefill_workspace_buffer, "NHD")
  1050. model_input.attn_metadata.prefill_wrapper = \
  1051. self.flashinfer_prefill_wrapper
  1052. if model_input.attn_metadata.use_cuda_graph:
  1053. batch_size = model_input.input_tokens.shape[0]
  1054. model_input.attn_metadata.decode_wrapper = self.graph_runners[
  1055. batch_size].flashinfer_decode_wrapper
  1056. else:
  1057. model_input.attn_metadata.decode_wrapper = \
  1058. self.flashinfer_decode_wrapper
  1059. model_input.attn_metadata.begin_forward()
  1060. # Currently cuda graph is only supported by the decode phase.
  1061. assert model_input.attn_metadata is not None
  1062. prefill_meta = model_input.attn_metadata.prefill_metadata
  1063. decode_meta = model_input.attn_metadata.decode_metadata
  1064. # TODO: We can remove this once all
  1065. # virtual engines share the same kv cache.
  1066. virtual_engine = model_input.virtual_engine
  1067. if prefill_meta is None and decode_meta.use_cuda_graph:
  1068. assert model_input.input_tokens is not None
  1069. graph_batch_size = model_input.input_tokens.shape[0]
  1070. model_executable = self.graph_runners[virtual_engine][
  1071. graph_batch_size]
  1072. else:
  1073. model_executable = self.model
  1074. multi_modal_kwargs = model_input.multi_modal_kwargs or {}
  1075. hidden_or_intermediate_states = model_executable(
  1076. input_ids=model_input.input_tokens,
  1077. positions=model_input.input_positions,
  1078. kv_caches=kv_caches,
  1079. attn_metadata=model_input.attn_metadata,
  1080. intermediate_tensors=intermediate_tensors,
  1081. **multi_modal_kwargs,
  1082. )
  1083. # Compute the logits in the last pipeline stage.
  1084. if not get_pp_group().is_last_rank:
  1085. return hidden_or_intermediate_states
  1086. logits = self.model.compute_logits(hidden_or_intermediate_states,
  1087. model_input.sampling_metadata)
  1088. if not self.is_driver_worker:
  1089. return []
  1090. # Sample the next token.
  1091. output: SamplerOutput = self.model.sample(
  1092. logits=logits,
  1093. sampling_metadata=model_input.sampling_metadata,
  1094. )
  1095. if self.return_hidden_states:
  1096. # we only need to pass hidden states of most recent token
  1097. assert model_input.sampling_metadata is not None
  1098. indices = model_input.sampling_metadata.selected_token_indices
  1099. if model_input.is_prompt:
  1100. hidden_states = hidden_or_intermediate_states.index_select(
  1101. 0, indices)
  1102. elif decode_meta.use_cuda_graph:
  1103. hidden_states = hidden_or_intermediate_states[:len(indices)]
  1104. else:
  1105. hidden_states = hidden_or_intermediate_states
  1106. output.hidden_states = hidden_states
  1107. return [output]
  1108. class CUDAGraphRunner:
  1109. def __init__(self, model: nn.Module, backend_name: str):
  1110. self.model = model
  1111. self.backend_name = backend_name
  1112. self.input_buffers: Dict[str, torch.Tensor] = {}
  1113. self.output_buffers: Dict[str, torch.Tensor] = {}
  1114. self._graph: Optional[torch.cuda.CUDAGraph] = None
  1115. self.flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None
  1116. self.flashinfer_indptr_buffer: Optional[torch.Tensor] = None
  1117. self.flashinfer_indices_buffer: Optional[torch.Tensor] = None
  1118. self.flashinfer_last_page_len_buffer: Optional[torch.Tensor] = None
  1119. self.flashinfer_decode_wrapper: Optional[
  1120. CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None
  1121. @property
  1122. def graph(self):
  1123. assert self._graph is not None
  1124. return self._graph
  1125. def capture(
  1126. self,
  1127. input_ids: torch.Tensor,
  1128. positions: torch.Tensor,
  1129. hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
  1130. torch.Tensor]],
  1131. intermediate_inputs: Optional[IntermediateTensors],
  1132. kv_caches: List[torch.Tensor],
  1133. attn_metadata: AttentionMetadata,
  1134. memory_pool: Optional[Tuple[int, int]],
  1135. stream: torch.cuda.Stream,
  1136. **kwargs,
  1137. ) -> Union[torch.Tensor, IntermediateTensors]:
  1138. assert self._graph is None
  1139. # Run the model a few times without capturing the graph.
  1140. # This is to make sure that the captured graph does not include the
  1141. # kernel launches for initial benchmarking (e.g., Triton autotune).
  1142. # Note one iteration is not enough for torch.jit.script
  1143. for _ in range(_NUM_WARMUP_ITERS):
  1144. self.model(
  1145. input_ids,
  1146. positions,
  1147. kv_caches,
  1148. attn_metadata,
  1149. intermediate_inputs,
  1150. **kwargs,
  1151. )
  1152. torch.cuda.synchronize()
  1153. # Capture the graph.
  1154. self._graph = torch.cuda.CUDAGraph()
  1155. with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
  1156. output_hidden_or_intermediate_states = self.model(
  1157. input_ids,
  1158. positions,
  1159. kv_caches,
  1160. attn_metadata,
  1161. intermediate_inputs,
  1162. **kwargs,
  1163. )
  1164. if hidden_or_intermediate_states is not None:
  1165. if get_pp_group().is_last_rank:
  1166. hidden_or_intermediate_states.copy_(
  1167. output_hidden_or_intermediate_states)
  1168. else:
  1169. for key in hidden_or_intermediate_states.tensors:
  1170. hidden_or_intermediate_states[key].copy_(
  1171. output_hidden_or_intermediate_states[key])
  1172. else:
  1173. hidden_or_intermediate_states = (
  1174. output_hidden_or_intermediate_states)
  1175. del output_hidden_or_intermediate_states
  1176. # make sure `output_hidden_states` is deleted
  1177. # in the graph's memory pool
  1178. gc.collect()
  1179. torch.cuda.synchronize()
  1180. # Save the input and output buffers.
  1181. if self.backend_name == "flashinfer":
  1182. self.input_buffers = {
  1183. "input_ids": input_ids,
  1184. "positions": positions,
  1185. "kv_caches": kv_caches,
  1186. "slot_mapping": attn_metadata.slot_mapping,
  1187. }
  1188. else:
  1189. self.input_buffers = {
  1190. "input_ids": input_ids,
  1191. "positions": positions,
  1192. "kv_caches": kv_caches,
  1193. "slot_mapping": attn_metadata.slot_mapping,
  1194. "seq_lens_tensor":
  1195. attn_metadata.decode_metadata.seq_lens_tensor,
  1196. "block_tables": attn_metadata.decode_metadata.block_tables,
  1197. }
  1198. if intermediate_inputs is not None:
  1199. self.input_buffers.update(intermediate_inputs.tensors)
  1200. if get_pp_group().is_last_rank:
  1201. self.output_buffers = {
  1202. "hidden_states": hidden_or_intermediate_states
  1203. }
  1204. else:
  1205. self.output_buffers = hidden_or_intermediate_states
  1206. return hidden_or_intermediate_states
  1207. def forward(
  1208. self,
  1209. input_ids: torch.Tensor,
  1210. positions: torch.Tensor,
  1211. kv_caches: List[torch.Tensor],
  1212. attn_metadata: AttentionMetadata,
  1213. intermediate_tensors: Optional[IntermediateTensors],
  1214. **kwargs,
  1215. ) -> torch.Tensor:
  1216. # KV caches are fixed tensors, so we don't need to copy them.
  1217. del kv_caches
  1218. # Copy the input tensors to the input buffers.
  1219. self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
  1220. self.input_buffers["positions"].copy_(positions, non_blocking=True)
  1221. self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
  1222. non_blocking=True)
  1223. if self.backend_name != "flashinfer":
  1224. self.input_buffers["seq_lens_tensor"].copy_(
  1225. attn_metadata.decode_metadata.seq_lens_tensor,
  1226. non_blocking=True)
  1227. self.input_buffers["block_tables"].copy_(
  1228. attn_metadata.decode_metadata.block_tables, non_blocking=True)
  1229. if intermediate_tensors is not None:
  1230. for key in intermediate_tensors.tensors:
  1231. self.input_buffers[key].copy_(intermediate_tensors[key],
  1232. non_blocking=True)
  1233. # Run the graph.
  1234. self.graph.replay()
  1235. # Return the output tensor.
  1236. if get_pp_group().is_last_rank:
  1237. return self.output_buffers["hidden_states"]
  1238. return self.output_buffers
  1239. def __call__(self, *args, **kwargs):
  1240. return self.forward(*args, **kwargs)
  1241. def _get_graph_batch_size(batch_size: int) -> int:
  1242. """Returns the padded batch size given actual batch size.
  1243. Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
  1244. 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
  1245. """
  1246. if batch_size <= 2:
  1247. return batch_size
  1248. elif batch_size <= 4:
  1249. return 4
  1250. else:
  1251. return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
  1252. _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
  1253. def _is_block_tables_empty(block_tables: Union[None, Dict]):
  1254. """
  1255. Check if block_tables is None or a dictionary with all None values.
  1256. """
  1257. if block_tables is None:
  1258. return True
  1259. if isinstance(block_tables, dict) and all(
  1260. value is None for value in block_tables.values()):
  1261. return True
  1262. return False