model_runner.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118
  1. import gc
  2. import time
  3. import warnings
  4. from collections import defaultdict
  5. from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
  6. import numpy as np
  7. import torch
  8. import torch.nn as nn
  9. from loguru import logger
  10. from aphrodite.attention import AttentionMetadata, get_attn_backend
  11. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  12. LoRAConfig, ModelConfig, ParallelConfig,
  13. SchedulerConfig, VisionLanguageConfig)
  14. from aphrodite.common.sampling_params import SamplingParams
  15. from aphrodite.common.sequence import (SamplerOutput, SequenceData,
  16. SequenceGroupMetadata)
  17. from aphrodite.common.utils import (CudaMemoryProfiler,
  18. get_kv_cache_torch_dtype, is_hip,
  19. is_pin_memory_available,
  20. make_tensor_with_pad)
  21. from aphrodite.distributed import broadcast_tensor_dict
  22. from aphrodite.distributed.parallel_state import graph_capture
  23. from aphrodite.distributed.parallel_state import (
  24. get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
  25. from aphrodite.lora.layers import LoRAMapping
  26. from aphrodite.lora.request import LoRARequest
  27. from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
  28. from aphrodite.modeling import SamplingMetadata
  29. from aphrodite.modeling.model_loader import get_model
  30. from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
  31. from aphrodite.multimodal import MULTIMODAL_REGISTRY
  # Slot index written into the slot mapping for tokens whose KV entries must
  # not be stored: CUDA-graph padding, sliding-window-masked prompt tokens,
  # and the dummy mapping used while block tables are not yet allocated.
  _PAD_SLOT_ID = -1
  # LoRA rank for warmup adapters; its consumer is outside this view
  # (presumably the profiling / warmup path — TODO confirm).
  LORA_WARMUP_RANK = 8
  # Step between consecutive captured CUDA-graph batch sizes above 4.
  _BATCH_SIZE_ALIGNMENT = 8
  # Batch sizes for which CUDA graphs are captured: 1, 2, 4, then every
  # multiple of _BATCH_SIZE_ALIGNMENT up to 32 * _BATCH_SIZE_ALIGNMENT
  # (8, 16, 24, ..., 256).
  # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
  _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + list(
      range(_BATCH_SIZE_ALIGNMENT, 33 * _BATCH_SIZE_ALIGNMENT,
            _BATCH_SIZE_ALIGNMENT))
  # Warmup iteration count; used outside this view (presumably before graph
  # capture — TODO confirm).
  _NUM_WARMUP_ITERS = 2
  class ModelInput(NamedTuple):
      """Batched inputs for one forward pass, built by the model runner.

      Token-level tensors are laid out prefill-first: the first
      ``num_prefill_tokens`` entries belong to prefill sequences and the
      following ``num_decode_tokens`` entries to decode sequences.
      """
      # Flattened token ids of the whole batch.
      input_tokens: torch.Tensor
      # Position of each entry of ``input_tokens`` within its sequence.
      input_positions: torch.Tensor
      # Backend-specific attention metadata; None for an empty batch.
      attn_metadata: Optional[AttentionMetadata]
      # Per-sequence lengths handed to attention (sliding-window capped for
      # decodes; CUDA-graph padding entries are 1).
      seq_lens: List[int]
      # Per-sequence number of tokens processed in this step.
      query_lens: List[int]
      # Token -> LoRA id mapping; None when LoRA is disabled.
      lora_mapping: Optional[LoRAMapping]
      # LoRA requests referenced by the batch.
      lora_requests: Set[LoRARequest]
      # Multi-modal processor outputs, concatenated along dim 0 per key.
      multi_modal_kwargs: Dict[str, torch.Tensor]
      # KV-cache slot per token; _PAD_SLOT_ID where nothing is written.
      slot_mapping: torch.Tensor
      num_prefill_tokens: int
      num_decode_tokens: int
      num_prefills: int

      @classmethod
      def empty(cls, device):
          """Return an instance describing an empty batch on ``device``.

          Uses ``cls`` so subclasses get their own type back. The index
          tensors are ``torch.long``, the same dtype the runner uses for
          non-empty batches, so downstream consumers see a consistent dtype.
          """
          return cls(
              input_tokens=torch.empty(0, dtype=torch.long, device=device),
              input_positions=torch.empty(0, dtype=torch.long, device=device),
              attn_metadata=None,
              seq_lens=[],
              query_lens=[],
              lora_mapping=None,
              lora_requests=set(),
              multi_modal_kwargs={},
              slot_mapping=torch.empty(0, dtype=torch.long, device=device),
              num_prefill_tokens=0,
              num_decode_tokens=0,
              num_prefills=0,
          )
  70. class ModelRunner:
  def __init__(
      self,
      model_config: ModelConfig,
      parallel_config: ParallelConfig,
      scheduler_config: SchedulerConfig,
      device_config: DeviceConfig,
      cache_config: CacheConfig,
      load_config: LoadConfig,
      lora_config: Optional[LoRAConfig],
      kv_cache_dtype: Optional[str] = "auto",
      is_driver_worker: bool = False,
      vision_language_config: Optional[VisionLanguageConfig] = None,
  ):
      """Store the configs and set up attention / CUDA-graph bookkeeping.

      The model is not loaded here; ``load_model`` does that.

      Args:
          kv_cache_dtype: KV-cache dtype string, forwarded to the attention
              backend selector.
          is_driver_worker: True on the worker that prepares inputs and
              samples; other workers receive inputs via broadcast.
          vision_language_config: When not None, a multi-modal input
              processor is created from ``MULTIMODAL_REGISTRY``.
      """
      # Config objects, kept for later stages.
      self.model_config = model_config
      self.parallel_config = parallel_config
      self.scheduler_config = scheduler_config
      self.device_config = device_config
      self.cache_config = cache_config
      self.lora_config = lora_config
      self.load_config = load_config
      self.is_driver_worker = is_driver_worker
      self.vision_language_config = vision_language_config

      # Frequently read derived settings.
      self.device = device_config.device
      self.pin_memory = is_pin_memory_available()
      self.kv_cache_dtype = kv_cache_dtype
      self.sliding_window = model_config.get_sliding_window()
      self.block_size = cache_config.block_size
      self.max_seq_len_to_capture = model_config.max_seq_len_to_capture

      # CUDA-graph state.
      self.graph_runners: Dict[int, CUDAGraphRunner] = {}
      # Set during graph capture.
      self.graph_memory_pool: Optional[Tuple[int, int]] = None
      # Graph replays need block tables padded to max_seq_len_to_capture.
      # Building that padding in Python every step is costly, so a zeroed
      # numpy table of shape (largest captured batch size, blocks needed for
      # max_seq_len_to_capture) is allocated once and reused.
      largest_graph_batch = max(_BATCH_SIZES_TO_CAPTURE)
      self.graph_block_tables = np.zeros(
          (largest_graph_batch, self.get_max_block_per_batch()),
          dtype=np.int32)

      num_attention_heads = model_config.get_num_attention_heads(
          parallel_config)
      head_size = model_config.get_head_size()
      num_kv_heads = model_config.get_num_kv_heads(parallel_config)
      self.attn_backend = get_attn_backend(
          num_attention_heads,
          head_size,
          num_kv_heads,
          model_config.get_sliding_window(),
          model_config.dtype,
          self.kv_cache_dtype,
          self.block_size,
      )

      # Multi-modal inputs are only processed for vision-language models.
      if vision_language_config is None:
          self.multi_modal_input_processor = None
      else:
          self.multi_modal_input_processor = (
              MULTIMODAL_REGISTRY.create_input_processor(
                  model_config,
                  vision_language_config,
              ))

      # Lazily initialized attributes.
      self.model: nn.Module  # Set after load_model
      # Assigned only by the flashinfer path; left unset so that hasattr()
      # tells whether the workspace buffer has been allocated yet.
      self.flashinfer_workspace_buffer: torch.Tensor
      # Set after load_model.
      self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
  def load_model(self) -> None:
      """Load the model weights and apply optional LoRA / FP8 KV-cache setup.

      Sets ``self.model`` and ``self.model_memory_usage`` (bytes reported by
      ``CudaMemoryProfiler``), then logs load time and memory usage (per
      rank when tensor parallelism is used). When LoRA is configured the
      model is wrapped by an ``LRUCacheWorkerLoRAManager``. On ROCm with an
      ``"fp8"`` KV cache, scaling factors are loaded from
      ``model_config.quantization_param_path`` when one is given.

      Raises:
          AssertionError: LoRA is configured but the model lacks
              ``supported_lora_modules``, ``embedding_modules`` or
              ``embedding_padding_modules``.
          RuntimeError: FP8 scaling factors were provided but the model has
              no callable ``load_kv_cache_scales``.
      """
      with CudaMemoryProfiler() as m:
          # measure the time it takes to load the model
          start_time = time.time()
          self.model = get_model(
              model_config=self.model_config,
              device_config=self.device_config,
              load_config=self.load_config,
              lora_config=self.lora_config,
              vision_language_config=self.vision_language_config,
              parallel_config=self.parallel_config,
              scheduler_config=self.scheduler_config,
              cache_config=self.cache_config,
          )
          end_time = time.time()
      # Read once the profiler context has exited.
      self.model_memory_usage = m.consumed_memory
      tp = get_tensor_model_parallel_world_size()
      rank = get_tensor_model_parallel_rank()
      total_time = end_time - start_time
      if tp > 1:
          logger.info(
              f"Rank {rank}: Model weights loaded in {total_time:.2f} secs.")
          # Only rank 0 reports memory, as per-rank usage times tp.
          if rank == 0:
              logger.info(
                  "Memory usage: "
                  f"{self.model_memory_usage / float(2**30):.2f} GiB x {tp} ="
                  f" {self.model_memory_usage * tp / float(2**30):.2f} GiB")
      else:
          logger.info(f"Model weights loaded in {total_time:.2f} seconds.")
          logger.info("Memory usage: "
                      f"{self.model_memory_usage / float(2**30):.2f} GiB")

      if self.lora_config:
          assert hasattr(self.model, "supported_lora_modules"
                         ) and self.model.supported_lora_modules, (
                             "Model does not support LoRA")
          assert hasattr(
              self.model,
              "embedding_modules"), "Model does not have embedding_modules"
          assert hasattr(self.model, "embedding_padding_modules"
                         ), "Model does not have embedding_padding_modules"
          self.lora_manager = LRUCacheWorkerLoRAManager(
              self.scheduler_config.max_num_seqs,
              self.scheduler_config.max_num_batched_tokens,
              self.vocab_size,
              self.lora_config,
              self.device,
              self.model.embedding_modules,
              self.model.embedding_padding_modules,
              max_position_embeddings=self.model.config.
              max_position_embeddings,
          )
          self.model = self.lora_manager.create_lora_manager(self.model)

      if self.kv_cache_dtype == "fp8" and is_hip():
          # Currently only ROCm accepts kv-cache scaling factors
          # via quantization_param_path and this will be deprecated
          # in the future.
          if self.model_config.quantization_param_path is not None:
              if callable(getattr(self.model, "load_kv_cache_scales", None)):
                  warnings.warn(
                      "Loading kv cache scaling factor from JSON is "
                      "deprecated and will be removed. Please include "
                      "kv cache scaling factors in the model checkpoint.",
                      FutureWarning,
                      stacklevel=2)
                  self.model.load_kv_cache_scales(
                      self.model_config.quantization_param_path)
                  # Build a single message: a second positional argument is
                  # treated by loguru as a str.format() argument, and with no
                  # placeholder in the message the path was silently dropped.
                  logger.info(
                      "Loaded KV cache scaling factors from "
                      f"{self.model_config.quantization_param_path}")
              else:
                  raise RuntimeError(
                      "Using FP8 KV cache and scaling factors provided but "
                      f"model {self.model.__class__} does not support loading"
                      " scaling factors.", )
          else:
              logger.warning(
                  "Using FP8 KV cache but no scaling factors "
                  "provided. Defaulting to scaling factors of 1.0. "
                  "This may lead to less accurate results!")
  def save_sharded_state(
      self,
      path: str,
      pattern: Optional[str] = None,
      max_size: Optional[int] = None,
  ) -> None:
      """Write the loaded model's state as sharded files under ``path``.

      Delegates to ``ShardedStateLoader.save_model``; ``pattern`` and
      ``max_size`` are forwarded unchanged.
      """
      # Local import, as in the original; the reason is not visible here
      # (possibly import-cycle avoidance — unverified).
      from aphrodite.modeling.model_loader.loader import ShardedStateLoader
      save = ShardedStateLoader.save_model
      save(self.model, path, pattern=pattern, max_size=max_size)
  def save_tensorized_model(
      self,
      tensorizer_config: TensorizerConfig,
  ) -> None:
      """Serialize the loaded model with Tensorizer.

      Delegates to ``TensorizerLoader.save_model`` with the given
      ``tensorizer_config``.
      """
      from aphrodite.modeling.model_loader.loader import TensorizerLoader
      model = self.model
      TensorizerLoader.save_model(model, tensorizer_config=tensorizer_config)
  def get_max_block_per_batch(self) -> int:
      """Return the number of KV-cache blocks needed for the longest
      sequence a CUDA graph is captured for.

      This is ``ceil(max_seq_len_to_capture / block_size)``, computed with
      integer arithmetic.
      """
      size = self.block_size
      rounded_up = self.max_seq_len_to_capture + (size - 1)
      return rounded_up // size
  def _prepare_model_input(
      self,
      seq_group_metadata_list: List[SequenceGroupMetadata],
  ) -> ModelInput:
      """Prepare the model input based on a given sequence group.

      The API assumes seq_group_metadata_list is sorted by prefill -> decode.

      The result tensors and data structure also batches input in prefill
      -> decode order. For example,

      - input_tokens[:num_prefill_tokens] contains prefill tokens.
      - input_tokens[num_prefill_tokens:] contains decode tokens.

      If cuda graph is required, this API automatically pads inputs.

      Args:
          seq_group_metadata_list: Sequence groups scheduled for this step.

      Returns:
          A ``ModelInput`` holding device tensors plus attention metadata
          built by ``self.attn_backend``; ``ModelInput.empty(self.device)``
          when ``seq_group_metadata_list`` is empty.

      Raises:
          RuntimeError: chunked prefill is enabled and a sequence group has
              computed (prefix-cached) blocks.
          ValueError: a sequence group carries multi-modal data but no
              multi-modal input processor is configured.
      """
      input_tokens: List[int] = []
      input_positions: List[int] = []
      slot_mapping: List[int] = []
      lora_index_mapping: List[int] = []
      lora_prompt_mapping: List[int] = []
      lora_requests: Set[LoRARequest] = set()

      seq_lens: List[int] = []
      prefill_seq_lens: List[int] = []
      decode_seq_lens: List[int] = []
      context_lens: List[int] = []
      query_lens: List[int] = []
      block_tables: List[List[int]] = []
      multi_modal_kwargs_list: Dict[str,
                                    List[torch.Tensor]] = defaultdict(list)
      decode_only = True
      num_prefills = 0
      num_prefill_tokens = 0
      num_decode_tokens = 0

      # The following fields are only for flashinfer
      # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
      # for the precise definition of the following fields.
      # An example:
      # request 1, page indices [0, 5, 8]
      # request 2, page indices [1, 6, 7]
      # request 3, page indices [3, 4]
      # paged_kv_indices is a concatenation of page indices of all requests:
      # [0, 5, 8, 1, 6, 7, 3, 4]
      # paged_kv_indptr is used to index into paged_kv_indices:
      # [0, 3, 6, 8]
      paged_kv_indices: List[int] = []
      # 0 at the beginning of paged_kv_indptr indicates the start of the
      # first request's page indices in the paged_kv_indices list.
      paged_kv_indptr: List[int] = [0]
      # paged_kv_last_page_len is the length of the last page of each request
      paged_kv_last_page_len: List[int] = []

      if len(seq_group_metadata_list) == 0:
          return ModelInput.empty(self.device)

      # The backend name does not change inside the loops; look it up once.
      attn_backend_name = self.attn_backend.get_name()

      if self.sliding_window is not None:
          sliding_window_blocks = (self.sliding_window + self.block_size -
                                   1) // self.block_size
          block_aligned_sliding_window = \
              sliding_window_blocks * self.block_size

      for seq_group_metadata in seq_group_metadata_list:
          seq_ids = list(seq_group_metadata.seq_data.keys())
          is_prompt = seq_group_metadata.is_prompt

          for seq_id in seq_ids:
              computed_block_nums = seq_group_metadata.computed_block_nums
              if (self.scheduler_config is not None
                      and self.scheduler_config.chunked_prefill_enabled
                      and not (computed_block_nums is None
                               or computed_block_nums == [])):
                  raise RuntimeError(
                      "chunked prefill cannot be used with prefix caching "
                      "now.")

              seq_data = seq_group_metadata.seq_data[seq_id]
              if is_prompt:
                  context_len = seq_data.get_num_computed_tokens()
              else:
                  # get_num_computed_tokens is incorrect for spec decoding.
                  # So, we should have a special logic here.
                  # TODO: Fix it.
                  context_len = seq_data.get_len() - 1

              seq_len = min(
                  seq_data.get_len(),
                  context_len + seq_group_metadata.token_chunk_size)
              if is_prompt:
                  tokens = seq_data.get_token_ids()[context_len:seq_len]
              else:
                  # Optimization. get_token_ids requires the entire copy of
                  # tokens.
                  tokens = [seq_data.get_last_token_id()]

              # Prefix cache was hit.
              # Prefix is not supported with sliding_window
              prefix_cache_hit = (computed_block_nums is not None
                                  and len(computed_block_nums) > 0
                                  and self.sliding_window is None
                                  and is_prompt)

              # These are seq_len/context_len capped to the sliding window.
              # They are passed to decode kernel.
              # We still need original seq_len/context_len to compute slot
              # mapping (and input position) below.
              curr_sliding_window_blocks = None
              sliding_seq_len = seq_len
              sliding_context_len = context_len

              # TODO: This is a hack to make sliding window work with
              # paged attn. We can remove it if we make paged attn kernel
              # to properly handle sliding window attn.
              if (self.sliding_window is not None and not is_prompt):
                  curr_sliding_window_blocks = sliding_window_blocks
                  if self.scheduler_config.use_v2_block_manager:
                      # number of elements in last block
                      suff_len = seq_len % self.block_size
                      sliding_seq_len = min(
                          seq_len, block_aligned_sliding_window + suff_len)
                      if suff_len > 0:
                          curr_sliding_window_blocks += 1
                  else:
                      sliding_seq_len = min(seq_len, self.sliding_window)
                  sliding_context_len = sliding_seq_len - 1

              # TODO: Combine chunked prefill and prefix caching by
              # only allowing multiple of block_size chunk size.
              # NOTE: This only works for oooooooxxx style attention.
              if prefix_cache_hit:
                  assert computed_block_nums is not None
                  context_len = len(computed_block_nums) * self.block_size
                  tokens = tokens[context_len:]

                  # need to think what to set it to when we have both sliding
                  # window and prefix caching...
                  assert self.sliding_window is None, \
                      "Prefix caching is not supported with sliding window"
                  sliding_context_len = context_len

                  if attn_backend_name == "flash-attn":
                      # NOTE: For flash-attn, the block table should
                      # include the entries for the incoming prefill tokens.
                      # TODO: This is a temporary fix. We should
                      # provide a unified interface for different backends.
                      block_table = seq_group_metadata.block_tables[seq_id]
                  else:
                      block_table = computed_block_nums
              elif (self.scheduler_config.chunked_prefill_enabled
                    or not is_prompt):
                  if seq_group_metadata.block_tables is not None:
                      # chunked prefill or decode
                      block_table = seq_group_metadata.block_tables[seq_id]
                      if curr_sliding_window_blocks is not None:
                          block_table = block_table[
                              -curr_sliding_window_blocks:]
                      if attn_backend_name == "flashinfer":
                          paged_kv_indices.extend(block_table)
                          paged_kv_indptr.append(paged_kv_indptr[-1] +
                                                 len(block_table))
                          last_page_len = seq_data.get_len(
                          ) % self.block_size
                          if last_page_len == 0:
                              last_page_len = self.block_size
                          paged_kv_last_page_len.append(last_page_len)
                  else:
                      # Only happens when memory profiling runs.
                      block_table = []
              else:
                  # Prefill without chunked prefill or memory profiling.
                  block_table = []
              block_tables.append(block_table)

              seq_lens.append(sliding_seq_len)
              context_lens.append(sliding_context_len)
              query_len = sliding_seq_len - sliding_context_len
              query_lens.append(query_len)
              input_tokens.extend(tokens)
              input_positions.extend(list(range(context_len, seq_len)))
              lora_id = seq_group_metadata.lora_int_id

              if is_prompt:
                  assert len(seq_ids) == 1
                  num_prefills += 1
                  num_prefill_tokens += len(tokens)
                  decode_only = False
                  prefill_seq_lens.append(seq_len)
              else:
                  assert query_len == 1, (
                      "seq_len: {}, context_len: {}, query_len: {}".format(
                          seq_len, context_len, query_len))
                  num_decode_tokens += query_len
                  decode_seq_lens.append(sliding_seq_len)

              if lora_id > 0:
                  lora_requests.add(seq_group_metadata.lora_request)

              lora_index_mapping += [lora_id] * query_len
              # One prompt-mapping entry per token when prompt logprobs are
              # requested, otherwise a single entry for the sequence.
              lora_prompt_mapping.extend(
                  [lora_id] *
                  (query_len if seq_group_metadata.sampling_params
                   and seq_group_metadata.sampling_params.prompt_logprobs
                   is not None else 1))

              mm_data = seq_group_metadata.multi_modal_data
              if mm_data is not None:
                  # Process multi-modal data
                  if self.multi_modal_input_processor is None:
                      raise ValueError(
                          "Multi-modal inputs are only supported by "
                          "vision language models.")

                  mm_kwargs = self.multi_modal_input_processor(mm_data)
                  for k, v in mm_kwargs.items():
                      multi_modal_kwargs_list[k].append(v)

              if _is_block_tables_empty(seq_group_metadata.block_tables):
                  # During memory profiling, the block tables are not
                  # initialized yet. In this case, we just use a dummy
                  # slot mapping.
                  # In embeddings, the block tables are {seq_id: None}.
                  slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
                  continue

              # Compute the slot mapping.
              block_table = seq_group_metadata.block_tables[seq_id]

              # Mask the [0, start_idx) tokens of the prompt with
              # _PAD_SLOT_ID, where start_idx is max(0, seq_len -
              # sliding_window). For example, if the prompt len is 10,
              # sliding window is 8, and block size is 4, the first two
              # tokens are masked and the slot mapping will be
              # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
              start_idx = 0
              if self.sliding_window is not None:
                  if is_prompt:
                      assert self.scheduler_config.use_v2_block_manager \
                          or context_len == 0, (
                          "Prefix caching is currently not supported with "
                          "sliding window attention in V1 block manager")
                  # It is an optimization. When it is decoding, it is always
                  # 0. When prefill, we use it to not write slots to kv cache
                  # to save memory.
                  start_idx = max(0, query_len - self.sliding_window)

              for i in range(context_len, seq_len):
                  if i < start_idx:
                      slot_mapping.append(_PAD_SLOT_ID)
                      continue

                  block_number = block_table[i // self.block_size]
                  block_offset = i % self.block_size
                  slot = block_number * self.block_size + block_offset
                  slot_mapping.append(slot)

      batch_size = len(input_tokens)
      max_query_len = max(query_lens)
      max_prefill_seq_len = max(prefill_seq_lens, default=0)
      max_decode_seq_len = max(decode_seq_lens, default=0)

      # If cuda graph can be used, pad tensors accordingly.
      # See `capture_model` API for more details.
      # Aphrodite uses cuda graph only for decoding requests.
      use_captured_graph = (
          decode_only and not self.model_config.enforce_eager
          and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
          and max_decode_seq_len <= self.max_seq_len_to_capture)
      if use_captured_graph:
          graph_batch_size = _get_graph_batch_size(batch_size)
          assert graph_batch_size >= batch_size
          for _ in range(graph_batch_size - batch_size):
              input_tokens.append(0)
              input_positions.append(0)
              slot_mapping.append(_PAD_SLOT_ID)
              seq_lens.append(1)
              block_tables.append([])
              lora_index_mapping.append(0)
          batch_size = graph_batch_size
          num_decode_tokens = batch_size

      if use_captured_graph:
          # The shape of graph_block_tables is
          # [max batch size, max context len // block size].
          input_block_tables = self.graph_block_tables[:batch_size]
          for i, block_table in enumerate(block_tables):
              if block_table:
                  input_block_tables[i, :len(block_table)] = block_table
          block_tables = torch.tensor(input_block_tables, device=self.device)
      else:
          max_block_table_len = max(
              len(block_table) for block_table in block_tables)
          block_tables = make_tensor_with_pad(
              block_tables,
              max_len=max_block_table_len,
              pad=0,
              dtype=torch.int,
              device=self.device,
          )
      assert max_query_len > 0, ("query_lens: {}".format(query_lens))

      seq_lens_tensor = torch.tensor(seq_lens,
                                     dtype=torch.int,
                                     device=self.device)
      seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                  dtype=torch.int32,
                                  device=self.device)

      torch.cumsum(seq_lens_tensor,
                   dim=0,
                   dtype=seq_start_loc.dtype,
                   out=seq_start_loc[1:])

      input_tokens_tensor = torch.tensor(input_tokens,
                                         dtype=torch.long,
                                         device=self.device)
      input_positions_tensor = torch.tensor(input_positions,
                                            dtype=torch.long,
                                            device=self.device)
      slot_mapping_tensor = torch.tensor(slot_mapping,
                                         dtype=torch.long,
                                         device=self.device)

      if attn_backend_name == "flashinfer":
          if not hasattr(self, "flashinfer_workspace_buffer"):
              # Allocate 16MB workspace buffer
              # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
              self.flashinfer_workspace_buffer = torch.empty(
                  16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
          paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
                                                dtype=torch.int,
                                                device=self.device)
          paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
                                                 dtype=torch.int,
                                                 device=self.device)
          paged_kv_last_page_len_tensor = torch.tensor(
              paged_kv_last_page_len, dtype=torch.int, device=self.device)
          kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
                                                    self.model_config.dtype)
          attn_metadata = self.attn_backend.make_metadata(
              num_prefills=num_prefills,
              slot_mapping=slot_mapping_tensor,
              num_prefill_tokens=num_prefill_tokens,
              num_decode_tokens=num_decode_tokens,
              use_cuda_graph=False,
              max_prefill_seq_len=max_prefill_seq_len,
              block_tables=block_tables,
              workspace_buffer=self.flashinfer_workspace_buffer,
              paged_kv_indptr=paged_kv_indptr_tensor,
              paged_kv_indices=paged_kv_indices_tensor,
              paged_kv_last_page_len=paged_kv_last_page_len_tensor,
              num_qo_heads=self.model_config.get_num_attention_heads(
                  self.parallel_config),
              num_kv_heads=self.model_config.get_num_kv_heads(
                  self.parallel_config),
              head_dim=self.model_config.get_head_size(),
              # The page size must match the block size that
              # paged_kv_last_page_len was computed with above; a fixed 16
              # was wrong for any other block size.
              page_size=self.block_size,
              seq_start_loc=seq_start_loc,
              data_type=kv_cache_dtype)
      else:
          context_lens_tensor = torch.tensor(context_lens,
                                             dtype=torch.int,
                                             device=self.device)
          query_lens_tensor = torch.tensor(query_lens,
                                           dtype=torch.long,
                                           device=self.device)
          query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                        dtype=torch.int32,
                                        device=self.device)

          torch.cumsum(query_lens_tensor,
                       dim=0,
                       dtype=query_start_loc.dtype,
                       out=query_start_loc[1:])

          attn_metadata = self.attn_backend.make_metadata(
              num_prefills=num_prefills,
              slot_mapping=slot_mapping_tensor,
              num_prefill_tokens=num_prefill_tokens,
              num_decode_tokens=num_decode_tokens,
              seq_lens=seq_lens,
              seq_lens_tensor=seq_lens_tensor,
              max_query_len=max_query_len,
              max_prefill_seq_len=max_prefill_seq_len,
              max_decode_seq_len=max_decode_seq_len,
              query_start_loc=query_start_loc,
              seq_start_loc=seq_start_loc,
              context_lens_tensor=context_lens_tensor,
              block_tables=block_tables,
              use_cuda_graph=use_captured_graph,
          )

      if self.lora_config:
          lora_mapping = LoRAMapping(
              lora_index_mapping,
              lora_prompt_mapping,
          )
      else:
          lora_mapping = None

      multi_modal_kwargs = {
          k: torch.cat(v, dim=0).to(self.device)
          for k, v in multi_modal_kwargs_list.items()
      }

      return ModelInput(
          input_tokens=input_tokens_tensor,
          input_positions=input_positions_tensor,
          attn_metadata=attn_metadata,
          seq_lens=seq_lens,
          query_lens=query_lens,
          lora_mapping=lora_mapping,
          lora_requests=lora_requests,
          multi_modal_kwargs=multi_modal_kwargs,
          slot_mapping=slot_mapping_tensor,
          num_prefill_tokens=num_prefill_tokens,
          num_decode_tokens=num_decode_tokens,
          num_prefills=num_prefills,
      )
  def prepare_input_tensors(
      self,
      seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
  ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
             Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]:
      """Build (driver) or receive (other workers) the inputs for one step.

      The driver worker prepares the batch from ``seq_group_metadata_list``
      and broadcasts tensors and metadata from rank 0. Every other worker
      receives that broadcast and rebuilds the same inputs. On those workers
      the returned ``SamplingMetadata`` carries only
      ``selected_token_indices``.

      Returns:
          ``(input_tokens, input_positions, attn_metadata, sampling_metadata,
          lora_requests, lora_mapping, multi_modal_kwargs)``.
      """
      if not self.is_driver_worker:
          received = broadcast_tensor_dict(src=0)
          input_tokens = received.pop("input_tokens")
          input_positions = received.pop("input_positions")
          selected_token_indices = received.pop("selected_token_indices")
          lora_mapping = received.pop("lora_mapping")
          lora_requests = received.pop("lora_requests")
          multi_modal_kwargs = received.pop("multi_modal_kwargs")
          # Every remaining entry (the num_* counters, slot_mapping and any
          # attention-metadata fields) feeds the backend's metadata class.
          attn_metadata = (self.attn_backend.make_metadata(**received)
                           if received else None)
          sampling_metadata = SamplingMetadata(
              seq_groups=None,
              selected_token_indices=selected_token_indices,
              categorized_sample_indices=None,
              num_prompts=0,
          )
          return (input_tokens, input_positions, attn_metadata,
                  sampling_metadata, lora_requests, lora_mapping,
                  multi_modal_kwargs)

      assert seq_group_metadata_list is not None
      # Prepare input tensors.
      model_input = self._prepare_model_input(seq_group_metadata_list)
      sampling_metadata = SamplingMetadata.prepare(
          seq_group_metadata_list, model_input.seq_lens,
          model_input.query_lens, self.device, self.pin_memory)
      attn_metadata = model_input.attn_metadata

      payload = {
          "input_tokens": model_input.input_tokens,
          "input_positions": model_input.input_positions,
          "selected_token_indices":
          sampling_metadata.selected_token_indices,
          "lora_requests": model_input.lora_requests,
          "lora_mapping": model_input.lora_mapping,
          "multi_modal_kwargs": model_input.multi_modal_kwargs,
          "num_prefill_tokens": model_input.num_prefill_tokens,
          "num_decode_tokens": model_input.num_decode_tokens,
          "slot_mapping": model_input.slot_mapping,
          "num_prefills": model_input.num_prefills,
      }
      if attn_metadata:
          payload.update(attn_metadata.asdict_zerocopy())
      broadcast_tensor_dict(payload, src=0)
      return (model_input.input_tokens, model_input.input_positions,
              attn_metadata, sampling_metadata, model_input.lora_requests,
              model_input.lora_mapping, model_input.multi_modal_kwargs)
  @torch.inference_mode()
  def execute_model(
      self,
      seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
      kv_caches: List[torch.Tensor],
  ) -> Optional[SamplerOutput]:
      """Run one forward pass and, on the driver worker, sample tokens.

      A captured CUDA graph (keyed by the padded batch size, i.e. the number
      of input tokens) is replayed for decode-only batches whose decode
      metadata enables it; otherwise the eager model runs.

      Returns:
          The sampler output on the driver worker, ``None`` elsewhere.
      """
      (input_tokens, input_positions, attn_metadata, sampling_metadata,
       lora_requests, lora_mapping,
       multi_modal_kwargs) = self.prepare_input_tensors(seq_group_metadata_list)

      if self.lora_config:
          self.set_active_loras(lora_requests, lora_mapping)

      # Currently cuda graph is only supported by the decode phase.
      prefill_meta = attn_metadata.prefill_metadata
      decode_meta = attn_metadata.decode_metadata
      if prefill_meta is not None or not decode_meta.use_cuda_graph:
          runner = self.model
      else:
          runner = self.graph_runners[input_tokens.shape[0]]

      hidden_states = runner(
          input_ids=input_tokens,
          positions=input_positions,
          kv_caches=kv_caches,
          attn_metadata=attn_metadata,
          **multi_modal_kwargs,
      )

      # Compute the logits.
      logits = self.model.compute_logits(hidden_states, sampling_metadata)

      # Only the driver worker samples the next token.
      if self.is_driver_worker:
          return self.model.sample(
              logits=logits,
              sampling_metadata=sampling_metadata,
          )
      return None
  718. @torch.inference_mode()
  719. def profile_run(self) -> None:
  720. # Enable top-k sampling to reflect the accurate memory usage.
  721. sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
  722. max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
  723. max_num_seqs = self.scheduler_config.max_num_seqs
  724. # This represents the maximum number of different requests
  725. # that will have unique loras, an therefore the max amount of memory
  726. # consumption create dummy lora request copies from the lora request
  727. # passed in, which contains a lora from the lora warmup path.
  728. dummy_lora_requests = []
  729. dummy_lora_requests_per_seq = []
  730. if self.lora_config:
  731. assert self.lora_manager is not None
  732. with self.lora_manager.dummy_lora_cache():
  733. for idx in range(self.lora_config.max_loras):
  734. lora_id = idx + 1
  735. dummy_lora_request = LoRARequest(
  736. lora_name=f"warmup_{lora_id}",
  737. lora_int_id=lora_id,
  738. lora_local_path="/not/a/real/path",
  739. )
  740. self.lora_manager.add_dummy_lora(dummy_lora_request,
  741. rank=LORA_WARMUP_RANK)
  742. dummy_lora_requests.append(dummy_lora_request)
  743. dummy_lora_requests_per_seq = [
  744. dummy_lora_requests[idx % len(dummy_lora_requests)]
  745. for idx in range(max_num_seqs)
  746. ]
  747. # Profile memory usage with max_num_sequences sequences and the total
  748. # number of tokens equal to max_num_batched_tokens.
  749. seqs: List[SequenceGroupMetadata] = []
  750. # Additional GPU memory may be needed for vision encoding, which needs
  751. # to be accounted for when calculating the GPU blocks for
  752. # Aphrodite blocker manager.
  753. # To exercise the worst scenario for GPU memory consumption,
  754. # the number of seqs (batch_size) is chosen to maximize the number
  755. # of images processed.
  756. model_config = self.model_config
  757. vlm_config = self.vision_language_config
  758. if vlm_config:
  759. max_num_seqs = min(
  760. max_num_seqs,
  761. int(max_num_batched_tokens / vlm_config.image_feature_size))
  762. for group_id in range(max_num_seqs):
  763. seq_len = (max_num_batched_tokens // max_num_seqs +
  764. (group_id < max_num_batched_tokens % max_num_seqs))
  765. if vlm_config is None:
  766. seq_data = SequenceData([0] * seq_len)
  767. dummy_multi_modal_data = None
  768. else:
  769. seq_data, dummy_multi_modal_data = MULTIMODAL_REGISTRY \
  770. .dummy_data_for_profiling(seq_len, model_config, vlm_config)
  771. seq = SequenceGroupMetadata(
  772. request_id=str(group_id),
  773. is_prompt=True,
  774. seq_data={group_id: seq_data},
  775. sampling_params=sampling_params,
  776. block_tables=None,
  777. lora_request=dummy_lora_requests_per_seq[group_id]
  778. if dummy_lora_requests_per_seq else None,
  779. multi_modal_data=dummy_multi_modal_data,
  780. )
  781. seqs.append(seq)
  782. # Run the model with the dummy inputs.
  783. num_layers = self.model_config.get_num_layers(self.parallel_config)
  784. kv_caches = [None] * num_layers
  785. self.execute_model(seqs, kv_caches)
  786. torch.cuda.synchronize()
  787. return
  788. def remove_all_loras(self):
  789. if not self.lora_manager:
  790. raise RuntimeError("LoRA is not enabled.")
  791. self.lora_manager.remove_all_loras()
  792. def set_active_loras(self, lora_requests: Set[LoRARequest],
  793. lora_mapping: LoRAMapping) -> None:
  794. if not self.lora_manager:
  795. raise RuntimeError("LoRA is not enabled.")
  796. self.lora_manager.set_active_loras(lora_requests, lora_mapping)
  797. def add_lora(self, lora_request: LoRARequest) -> bool:
  798. if not self.lora_manager:
  799. raise RuntimeError("LoRA is not enabled.")
  800. return self.lora_manager.add_lora(lora_request)
  801. def remove_lora(self, lora_id: int) -> bool:
  802. if not self.lora_manager:
  803. raise RuntimeError("LoRA is not enabled.")
  804. return self.lora_manager.remove_lora(lora_id)
  805. def list_loras(self) -> Set[int]:
  806. if not self.lora_manager:
  807. raise RuntimeError("LoRA is not enabled.")
  808. return self.lora_manager.list_loras()
  809. @torch.inference_mode()
  810. def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
  811. """Cuda graph capture a model.
  812. Note that CUDA graph's performance gain is negligible if number
  813. of batched tokens are larger than 200. And since CUDA graph
  814. requires fixed sized tensors, supporting large/variable batch
  815. size requires high GPU memory overhead. Thus, Aphrodite only captures
  816. decoding requests. Mixed batch (chunked prefill + decoding) or
  817. prefill requests are not captured.
  818. Since it is used for decoding-only, it assumes there's only 1 token
  819. per sequence in the batch.
  820. """
  821. assert not self.model_config.enforce_eager
  822. logger.info("Capturing the model for CUDA graphs. This may lead to "
  823. "unexpected consequences if the model is not static. To "
  824. "run the model in eager mode, set 'enforce_eager=True' or "
  825. "use '--enforce-eager' in the CLI.")
  826. logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
  827. "If you are running out of memory, consider decreasing "
  828. "`gpu_memory_utilization` or enforcing eager mode. "
  829. "You can also reduce the `max_num_seqs` as needed "
  830. "to decrease memory usage.")
  831. start_time = time.perf_counter()
  832. # Prepare dummy inputs. These will be reused for all batch sizes.
  833. max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
  834. input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
  835. input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
  836. slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
  837. slot_mapping.fill_(_PAD_SLOT_ID)
  838. seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
  839. block_tables = torch.from_numpy(self.graph_block_tables).cuda()
  840. # Prepare buffer for outputs. These will be reused for all batch sizes.
  841. # It will be filled after the first graph capture.
  842. hidden_states: Optional[torch.Tensor] = None
  843. graph_batch_size = _get_graph_batch_size(
  844. self.scheduler_config.max_num_seqs)
  845. batch_size_capture_list = [
  846. bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
  847. ]
  848. with graph_capture() as graph_capture_context:
  849. # NOTE: Capturing the largest batch size first may help reduce the
  850. # memory usage of CUDA graph.
  851. for batch_size in reversed(batch_size_capture_list):
  852. # Create dummy attn_metadata.
  853. attn_metadata = self.attn_backend.make_metadata(
  854. num_prefills=0,
  855. num_prefill_tokens=0,
  856. num_decode_tokens=batch_size,
  857. slot_mapping=slot_mapping[:batch_size],
  858. seq_lens=None,
  859. seq_lens_tensor=seq_lens[:batch_size],
  860. max_query_len=None,
  861. max_prefill_seq_len=0,
  862. max_decode_seq_len=self.max_seq_len_to_capture,
  863. query_start_loc=None,
  864. seq_start_loc=None,
  865. context_lens_tensor=None,
  866. block_tables=block_tables[:batch_size],
  867. use_cuda_graph=True,
  868. )
  869. if self.lora_config:
  870. lora_mapping = LoRAMapping(
  871. [0] * batch_size,
  872. [0] * batch_size,
  873. )
  874. self.set_active_loras(set(), lora_mapping)
  875. graph_runner = CUDAGraphRunner(self.model)
  876. hidden_states = graph_runner.capture(
  877. input_tokens[:batch_size],
  878. input_positions[:batch_size],
  879. hidden_states[:batch_size]
  880. if hidden_states is not None else None,
  881. kv_caches,
  882. attn_metadata,
  883. memory_pool=self.graph_memory_pool,
  884. stream=graph_capture_context.stream,
  885. )
  886. self.graph_memory_pool = graph_runner.graph.pool()
  887. self.graph_runners[batch_size] = graph_runner
  888. end_time = time.perf_counter()
  889. elapsed_time = end_time - start_time
  890. # This usually takes < 10 seconds.
  891. logger.info(f"Graph capturing finished in {elapsed_time} secs.")
  892. @property
  893. def vocab_size(self) -> int:
  894. return self.model_config.get_vocab_size()
class CUDAGraphRunner:
    """Captures one model forward pass as a CUDA graph and replays it.

    Usage:

    * Call ``capture`` once with static input tensors. This warms the model
      up, records the graph, and remembers which tensors the graph reads
      from and writes to.
    * Afterwards, calling the runner (``forward``) copies fresh inputs into
      those captured buffers and replays the graph.
    """

    def __init__(self, model: nn.Module) -> None:
        self.model = model
        # Tensors captured by the graph; `forward` copies new inputs into
        # these (keys: input_ids, positions, kv_caches, slot_mapping,
        # seq_lens_tensor, block_tables).
        self.input_buffers: Dict[str, torch.Tensor] = {}
        # Captured output tensor, under the key "hidden_states".
        self.output_buffers: Dict[str, torch.Tensor] = {}
        # Set by `capture`; None until then.
        self._graph: Optional[torch.cuda.CUDAGraph] = None

    @property
    def graph(self) -> torch.cuda.CUDAGraph:
        """The captured graph; asserts that `capture` has already run."""
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: Optional[torch.Tensor],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.cuda.Stream,
        **kwargs,
    ) -> torch.Tensor:
        """Warm up the model, then record its forward pass as a CUDA graph.

        Args:
            input_ids: Static token-id buffer the graph will read.
            positions: Static position buffer the graph will read.
            hidden_states: Optional pre-existing output buffer. If given, the
                model output is copied into it inside the graph and it
                becomes the graph's output; otherwise the model's own output
                tensor is used.
            kv_caches: KV cache tensors passed to the model.
            attn_metadata: Attention metadata for the captured batch; its
                ``slot_mapping`` and decode-metadata ``seq_lens_tensor`` /
                ``block_tables`` are kept as input buffers.
            memory_pool: Graph memory pool handle to share with other
                captures (may be None).
            stream: Stream passed to ``torch.cuda.graph`` for the capture.
            **kwargs: Extra keyword arguments forwarded to the model during
                warmup and capture.

        Returns:
            The tensor that holds the graph's output (also stored in
            ``self.output_buffers["hidden_states"]``).
        """
        assert self._graph is None
        # Run the model a few times without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        # Note one iteration is not enough for torch.jit.script
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                **kwargs,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            output_hidden_states = self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                **kwargs,
            )
            # Either write into the caller's shared output buffer (the copy
            # is recorded in the graph) or adopt the model's output tensor.
            if hidden_states is not None:
                hidden_states.copy_(output_hidden_states)
            else:
                hidden_states = output_hidden_states
            del output_hidden_states
            # make sure `output_hidden_states` is deleted
            # in the graph's memory pool
            gc.collect()
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
            "block_tables": attn_metadata.decode_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}
        return hidden_states

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """Copy new inputs into the captured buffers and replay the graph.

        ``**kwargs`` is accepted for call compatibility with the model but is
        not used here.

        Returns:
            The captured output buffer. The same tensor object is returned on
            every call; presumably its contents are overwritten by the next
            replay, so callers should consume it before replaying again.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                 non_blocking=True)
        self.input_buffers["seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
        self.input_buffers["block_tables"].copy_(
            attn_metadata.decode_metadata.block_tables, non_blocking=True)
        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        # Makes the runner callable like the model it wraps.
        return self.forward(*args, **kwargs)
  985. def _get_graph_batch_size(batch_size: int) -> int:
  986. """Returns the padded batch size given actual batch size.
  987. Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
  988. 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
  989. """
  990. if batch_size <= 2:
  991. return batch_size
  992. elif batch_size <= 4:
  993. return 4
  994. else:
  995. return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
  996. _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
  997. def _is_block_tables_empty(block_tables: Union[None, Dict]):
  998. """
  999. Check if block_tables is None or a dictionary with all None values.
  1000. """
  1001. if block_tables is None:
  1002. return True
  1003. if isinstance(block_tables, dict) and all(
  1004. value is None for value in block_tables.values()):
  1005. return True
  1006. return False