model_runner.py 46 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074
  1. import time
  2. from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. from loguru import logger
  7. from aphrodite.attention import AttentionMetadata, get_attn_backend
  8. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  9. LoRAConfig, ModelConfig, ParallelConfig,
  10. SchedulerConfig, VisionLanguageConfig)
  11. from aphrodite.common.sampling_params import SamplingParams
  12. from aphrodite.common.sequence import (MultiModalData, SamplerOutput,
  13. SequenceData, SequenceGroupMetadata)
  14. from aphrodite.common.utils import (CudaMemoryProfiler,
  15. get_kv_cache_torch_dtype, is_hip,
  16. is_pin_memory_available,
  17. make_tensor_with_pad)
  18. from aphrodite.distributed import broadcast_tensor_dict
  19. from aphrodite.distributed.communication_op import graph_capture
  20. from aphrodite.distributed.parallel_state import \
  21. get_tensor_model_parallel_world_size
  22. from aphrodite.lora.layers import LoRAMapping
  23. from aphrodite.lora.request import LoRARequest
  24. from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
  25. from aphrodite.modeling import SamplingMetadata
  26. from aphrodite.modeling.model_loader import get_model
  27. _PAD_SLOT_ID = -1
  28. LORA_WARMUP_RANK = 8
  29. _BATCH_SIZE_ALIGNMENT = 8
  30. # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
  31. # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
  32. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
  33. _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
  34. ]
  35. class ModelInput(NamedTuple):
  36. input_tokens: torch.Tensor
  37. input_positions: torch.Tensor
  38. attn_metadata: Optional[AttentionMetadata]
  39. seq_lens: List[int]
  40. query_lens: List[int]
  41. lora_mapping: Optional[LoRAMapping]
  42. lora_requests: Set[LoRARequest]
  43. multi_modal_input: Optional[torch.Tensor]
  44. slot_mapping: torch.Tensor
  45. num_prefill_tokens: int
  46. num_decode_tokens: int
  47. num_prefills: int
  48. @classmethod
  49. def empty(cls, device):
  50. return ModelInput(
  51. input_tokens=torch.empty(0, device=device),
  52. input_positions=torch.empty(0, device=device),
  53. attn_metadata=None,
  54. seq_lens=[],
  55. query_lens=[],
  56. lora_mapping=None,
  57. lora_requests=set(),
  58. multi_modal_input=None,
  59. slot_mapping=torch.empty(0, device=device),
  60. num_prefill_tokens=0,
  61. num_decode_tokens=0,
  62. num_prefills=0,
  63. )
  64. class PrepareDecodeMetadata(NamedTuple):
  65. input_tokens: List[int]
  66. input_positions: List[int]
  67. attn_metadata: Optional[AttentionMetadata]
  68. lora_index_mapping: List[int]
  69. lora_prompt_mapping: List[int]
  70. lora_requests: Set[LoRARequest]
  71. slot_mapping: List[int]
  72. @classmethod
  73. def empty(cls):
  74. return PrepareDecodeMetadata(
  75. input_tokens=[],
  76. input_positions=[],
  77. attn_metadata=None,
  78. lora_index_mapping=[],
  79. lora_prompt_mapping=[],
  80. lora_requests=set(),
  81. slot_mapping=[],
  82. )
  83. class ModelRunner:
  84. def __init__(
  85. self,
  86. model_config: ModelConfig,
  87. parallel_config: ParallelConfig,
  88. scheduler_config: SchedulerConfig,
  89. device_config: DeviceConfig,
  90. cache_config: CacheConfig,
  91. load_config: LoadConfig,
  92. lora_config: Optional[LoRAConfig],
  93. kv_cache_dtype: Optional[str] = "auto",
  94. is_driver_worker: bool = False,
  95. vision_language_config: Optional[VisionLanguageConfig] = None,
  96. ):
  97. self.model_config = model_config
  98. self.parallel_config = parallel_config
  99. self.scheduler_config = scheduler_config
  100. self.device_config = device_config
  101. self.cache_config = cache_config
  102. self.lora_config = lora_config
  103. self.load_config = load_config
  104. self.is_driver_worker = is_driver_worker
  105. self.vision_language_config = vision_language_config
  106. self.device = self.device_config.device
  107. self.pin_memory = is_pin_memory_available()
  108. self.kv_cache_dtype = kv_cache_dtype
  109. self.sliding_window = model_config.get_sliding_window()
  110. self.block_size = cache_config.block_size
  111. self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
  112. self.graph_runners: Dict[int, CUDAGraphRunner] = {}
  113. self.graph_memory_pool: Optional[Tuple[
  114. int, int]] = None # Set during graph capture.
  115. # When using CUDA graph, the input block tables must be padded to
  116. # max_seq_len_to_capture. However, creating the block table in
  117. # Python can be expensive. To optimize this, we cache the block table
  118. # in numpy and only copy the actual input content at every iteration.
  119. # The shape of the cached block table will be
  120. # (max batch size to capture, max context len to capture / block size).
  121. self.graph_block_tables = np.zeros(
  122. (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
  123. dtype=np.int32)
  124. self.attn_backend = get_attn_backend(
  125. self.model_config.get_num_attention_heads(self.parallel_config),
  126. self.model_config.get_head_size(),
  127. self.model_config.get_num_kv_heads(self.parallel_config),
  128. self.model_config.get_sliding_window(),
  129. self.model_config.dtype,
  130. self.kv_cache_dtype,
  131. self.block_size,
  132. )
  133. # Lazy initialization
  134. self.model: nn.Module # Set after load_model
  135. # Set if the backend is flashinfer.
  136. self.flashinfer_workspace_buffer: torch.Tensor
  137. # Set after load_model.
  138. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
  139. def load_model(self) -> None:
  140. with CudaMemoryProfiler() as m:
  141. # measure the time it takes to load the model
  142. start_time = time.time()
  143. self.model = get_model(
  144. model_config=self.model_config,
  145. device_config=self.device_config,
  146. load_config=self.load_config,
  147. lora_config=self.lora_config,
  148. vision_language_config=self.vision_language_config,
  149. parallel_config=self.parallel_config,
  150. scheduler_config=self.scheduler_config,
  151. cache_config=self.cache_config,
  152. )
  153. end_time = time.time()
  154. self.model_memory_usage = m.consumed_memory
  155. tp = get_tensor_model_parallel_world_size()
  156. total_time = end_time - start_time
  157. logger.info(
  158. f"Model weights loaded in {total_time:.2f} seconds.\nMemory usage: "
  159. f"{self.model_memory_usage / float(2**30):.2f} GiB x {tp} = "
  160. f"{self.model_memory_usage * tp / float(2**30):.2f} GiB")
  161. if self.lora_config:
  162. assert hasattr(self.model, "supported_lora_modules"
  163. ) and self.model.supported_lora_modules, (
  164. "Model does not support LoRA")
  165. assert hasattr(
  166. self.model,
  167. "embedding_modules"), "Model does not have embedding_modules"
  168. assert hasattr(self.model, "embedding_padding_modules"
  169. ), "Model does not have embedding_padding_modules"
  170. self.lora_manager = LRUCacheWorkerLoRAManager(
  171. self.scheduler_config.max_num_seqs,
  172. self.scheduler_config.max_num_batched_tokens,
  173. self.vocab_size,
  174. self.lora_config,
  175. self.device,
  176. self.model.embedding_modules,
  177. self.model.embedding_padding_modules,
  178. max_position_embeddings=self.model.config.
  179. max_position_embeddings,
  180. )
  181. self.model = self.lora_manager.create_lora_manager(self.model)
  182. if self.kv_cache_dtype == "fp8" and is_hip():
  183. # Currently scaled KV cache is only enabled on ROCm
  184. if self.model_config.quantization_param_path is not None:
  185. if callable(getattr(self.model, "load_kv_cache_scales", None)):
  186. self.model.load_kv_cache_scales(
  187. self.model_config.quantization_param_path)
  188. else:
  189. raise RuntimeError("Using FP8 KV cache and scaling factors"
  190. " provided but model "
  191. f"{self.model.__class__} does not "
  192. "support loading scaling factors.")
  193. else:
  194. logger.warning(
  195. "Using FP8 KV cache but no scaling factors "
  196. "provided. Defaulting to scaling factors of 1.0. "
  197. "This may lead to less accurate results!")
  198. elif self.model_config.quantization_param_path is not None:
  199. logger.warning("KV cache scaling factors provided, "
  200. "but the KV cache data type is not FP8. "
  201. "KV cache scaling factors will not be used.")
  202. def save_sharded_state(
  203. self,
  204. path: str,
  205. pattern: Optional[str] = None,
  206. max_size: Optional[int] = None,
  207. ) -> None:
  208. from aphrodite.modeling.model_loader.loader import ShardedStateLoader
  209. ShardedStateLoader.save_model(
  210. self.model,
  211. path,
  212. pattern=pattern,
  213. max_size=max_size,
  214. )
  215. def get_max_block_per_batch(self) -> int:
  216. block_size = self.block_size
  217. return (self.max_seq_len_to_capture + block_size - 1) // block_size
  218. def _prepare_model_input(
  219. self,
  220. seq_group_metadata_list: List[SequenceGroupMetadata],
  221. ) -> ModelInput:
  222. """Prepare the model input based on a given sequence group.
  223. The API assumes seq_group_metadata_list is sorted by prefill -> decode.
  224. The result tensors and data structure also batches input in prefill
  225. -> decode order. For example,
  226. - input_tokens[:num_prefill_tokens] contains prefill tokens.
  227. - input_tokens[num_prefill_tokens:] contains decode tokens.
  228. If cuda graph is required, this API automatically pads inputs.
  229. """
  230. input_tokens: List[int] = []
  231. input_positions: List[int] = []
  232. slot_mapping: List[int] = []
  233. lora_index_mapping: List[int] = []
  234. lora_prompt_mapping: List[int] = []
  235. lora_requests: Set[LoRARequest] = set()
  236. seq_lens: List[int] = []
  237. prefill_seq_lens: List[int] = []
  238. decode_seq_lens: List[int] = []
  239. context_lens: List[int] = []
  240. query_lens: List[int] = []
  241. block_tables: List[List[int]] = []
  242. multi_modal_input_list: List[torch.Tensor] = []
  243. decode_only = True
  244. num_prefills = 0
  245. num_prefill_tokens = 0
  246. num_decode_tokens = 0
  247. # The following fields are only for flashinfer
  248. # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
  249. # for the precise definition of the following fields.
  250. # An example:
  251. # request 1, page indices [0, 5, 8]
  252. # request 2, page indices [1, 6, 7]
  253. # request 3, page indices [3, 4]
  254. # paged_kv_indices is a concatenation of page indices of all requests:
  255. # [0, 5, 8, 1, 6, 7, 3, 4]
  256. # paged_kv_indptr is used to index into paged_kv_indices:
  257. # [0, 3, 6, 8]
  258. paged_kv_indices: List[int] = []
  259. # 0 at the beginning of paged_kv_indptr indicates the start of the
  260. # first request’s page indices in the paged_kv_indices list.
  261. paged_kv_indptr: List[int] = [0]
  262. # paged_kv_last_page_len is the length of the last page of each request
  263. paged_kv_last_page_len: List[int] = []
  264. if len(seq_group_metadata_list) == 0:
  265. return ModelInput.empty(self.device)
  266. for seq_group_metadata in seq_group_metadata_list:
  267. seq_ids = list(seq_group_metadata.seq_data.keys())
  268. is_prompt = seq_group_metadata.is_prompt
  269. for seq_id in seq_ids:
  270. computed_block_nums = seq_group_metadata.computed_block_nums
  271. if (self.scheduler_config is not None
  272. and self.scheduler_config.chunked_prefill_enabled
  273. and not (computed_block_nums is None
  274. or computed_block_nums == [])):
  275. raise RuntimeError(
  276. "chunked prefill cannot be used with prefix caching "
  277. "now.")
  278. seq_data = seq_group_metadata.seq_data[seq_id]
  279. if is_prompt:
  280. context_len = seq_data.get_num_computed_tokens()
  281. else:
  282. # get_num_computed_tokens is incorrect for spec decoding.
  283. # So, we should have a special logic here.
  284. # TODO: Fix it.
  285. context_len = seq_data.get_len() - 1
  286. seq_len = min(
  287. seq_data.get_len(),
  288. context_len + seq_group_metadata.token_chunk_size)
  289. if is_prompt:
  290. tokens = seq_data.get_token_ids()[context_len:seq_len]
  291. else:
  292. # Optimization. get_token_ids requires the entire copy of
  293. # tokens.
  294. tokens = [seq_data.get_last_token_id()]
  295. # Prefix cache was hit.
  296. # Prefix is not supported with sliding_window
  297. prefix_cache_hit = (computed_block_nums is not None
  298. and len(computed_block_nums) > 0
  299. and self.sliding_window is None
  300. and is_prompt)
  301. # TODO: Combine chunked prefill and prefix caching by
  302. # only allowing multiple of block_size chunk size.
  303. # NOTE: This only works for oooooooxxx style attention.
  304. if prefix_cache_hit:
  305. assert computed_block_nums is not None
  306. context_len = len(computed_block_nums) * self.block_size
  307. tokens = tokens[context_len:]
  308. if self.attn_backend.get_name() == "flash-attn":
  309. # NOTE: For flash-attn, the block table should
  310. # include the entries for the incoming prefill tokens.
  311. # TODO: This is a temporary fix. We should
  312. # provide a unified interface for different backends.
  313. block_table = seq_group_metadata.block_tables[seq_id]
  314. else:
  315. block_table = computed_block_nums
  316. elif (self.scheduler_config.chunked_prefill_enabled
  317. or not is_prompt):
  318. if seq_group_metadata.block_tables is not None:
  319. # chunked prefill or decode
  320. block_table = seq_group_metadata.block_tables[seq_id]
  321. if self.sliding_window is not None:
  322. # chunked prefill doesn't support sliding window.
  323. assert (not self.scheduler_config.
  324. chunked_prefill_enabled)
  325. sliding_window_blocks = (self.sliding_window //
  326. self.block_size)
  327. block_table = block_table[-sliding_window_blocks:]
  328. if self.attn_backend.get_name() == "flashinfer":
  329. paged_kv_indices.extend(block_table)
  330. paged_kv_indptr.append(paged_kv_indptr[-1] +
  331. len(block_table))
  332. last_page_len = seq_data.get_len(
  333. ) % self.block_size
  334. if last_page_len == 0:
  335. last_page_len = self.block_size
  336. paged_kv_last_page_len.append(last_page_len)
  337. else:
  338. # Only happens when memory profiling runs.
  339. block_table = []
  340. else:
  341. # Prefill without chunked prefill or memory profiling.
  342. block_table = []
  343. block_tables.append(block_table)
  344. # TODO: This is a hack to make sliding window work with
  345. # paged attn. We can remove it if we make paged attn kernel
  346. # to properly handle slinding window attn.
  347. if (self.sliding_window is not None and not is_prompt):
  348. seq_len = min(seq_len, self.sliding_window)
  349. context_len = seq_len - 1
  350. seq_lens.append(seq_len)
  351. context_lens.append(context_len)
  352. query_len = seq_len - context_len
  353. query_lens.append(query_len)
  354. input_tokens.extend(tokens)
  355. input_positions.extend(list(range(context_len, seq_len)))
  356. lora_id = seq_group_metadata.lora_int_id
  357. if is_prompt:
  358. assert len(seq_ids) == 1
  359. num_prefills += 1
  360. num_prefill_tokens += len(tokens)
  361. decode_only = False
  362. prefill_seq_lens.append(seq_len)
  363. else:
  364. assert query_len == 1, (
  365. "seq_len: {}, context_len: {}, query_len: {}".format(
  366. seq_len, context_len, query_len))
  367. num_decode_tokens += query_len
  368. decode_seq_lens.append(seq_len)
  369. if lora_id > 0:
  370. lora_requests.add(seq_group_metadata.lora_request)
  371. lora_index_mapping += [lora_id] * (seq_len - context_len)
  372. lora_prompt_mapping.extend(
  373. [lora_id] *
  374. (seq_len -
  375. context_len if seq_group_metadata.sampling_params
  376. and seq_group_metadata.sampling_params.prompt_logprobs
  377. else 1))
  378. if seq_group_metadata.multi_modal_data:
  379. multi_modal_input_list.append(
  380. seq_group_metadata.multi_modal_data.data)
  381. if _is_block_tables_empty(seq_group_metadata.block_tables):
  382. # During memory profiling, the block tables are not
  383. # initialized yet. In this case, we just use a dummy
  384. # slot mapping.
  385. # In embeddings, the block tables are {seq_id: None}.
  386. slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
  387. continue
  388. # Compute the slot mapping.
  389. block_table = seq_group_metadata.block_tables[seq_id]
  390. # Mask the [0, start_idx) tokens of the prompt with
  391. # _PAD_SLOT_ID, where start_idx is max(0, seq_len -
  392. # sliding_window). For example, if the prompt len is 10,
  393. # sliding window is 8, and block size is 4, the first two
  394. # tokens are masked and the slot mapping will be
  395. # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
  396. start_idx = 0
  397. if self.sliding_window is not None:
  398. if is_prompt:
  399. assert context_len == 0, (
  400. "Prefix caching is currently not supported with "
  401. "sliding window attention")
  402. # It is an optimization. When it is decoding, it is always
  403. # 0. When prefill, we use it to not write slots to kv cache
  404. # to save memory.
  405. start_idx = max(0, query_len - self.sliding_window)
  406. for i in range(context_len, seq_len):
  407. if i < start_idx:
  408. slot_mapping.append(_PAD_SLOT_ID)
  409. continue
  410. block_number = block_table[i // self.block_size]
  411. block_offset = i % self.block_size
  412. slot = block_number * self.block_size + block_offset
  413. slot_mapping.append(slot)
  414. batch_size = len(input_tokens)
  415. max_query_len = max(query_lens)
  416. max_prefill_seq_len = max(prefill_seq_lens, default=0)
  417. max_decode_seq_len = max(decode_seq_lens, default=0)
  418. # If cuda graph can be used, pad tensors accordingly.
  419. # See `capture_model` API for more details.
  420. # vLLM uses cuda graph only for decoding requests.
  421. use_captured_graph = (
  422. decode_only and not self.model_config.enforce_eager
  423. and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
  424. and max_decode_seq_len <= self.max_seq_len_to_capture)
  425. if use_captured_graph:
  426. graph_batch_size = _get_graph_batch_size(batch_size)
  427. assert graph_batch_size >= batch_size
  428. for _ in range(graph_batch_size - batch_size):
  429. input_tokens.append(0)
  430. input_positions.append(0)
  431. slot_mapping.append(_PAD_SLOT_ID)
  432. seq_lens.append(1)
  433. block_tables.append([])
  434. lora_index_mapping.append(0)
  435. batch_size = graph_batch_size
  436. num_decode_tokens = batch_size
  437. if use_captured_graph:
  438. # The shape of graph_block_tables is
  439. # [max batch size, max context len // block size].
  440. input_block_tables = self.graph_block_tables[:batch_size]
  441. for i, block_table in enumerate(block_tables):
  442. if block_table:
  443. input_block_tables[i, :len(block_table)] = block_table
  444. block_tables = torch.tensor(input_block_tables, device=self.device)
  445. else:
  446. max_block_table_len = max(
  447. len(block_table) for block_table in block_tables)
  448. block_tables = make_tensor_with_pad(
  449. block_tables,
  450. max_len=max_block_table_len,
  451. pad=0,
  452. dtype=torch.int,
  453. device=self.device,
  454. )
  455. assert max_query_len > 0, ("query_lens: {}".format(query_lens))
  456. context_lens_tensor = torch.tensor(context_lens,
  457. dtype=torch.int,
  458. device=self.device)
  459. if multi_modal_input_list:
  460. assert self.vision_language_config, (
  461. "Multi-modal inputs are only supported by "
  462. "vision language models.")
  463. multi_modal_input = torch.cat(multi_modal_input_list,
  464. dim=0).to(self.device)
  465. else:
  466. multi_modal_input = None
  467. seq_lens_tensor = torch.tensor(seq_lens,
  468. dtype=torch.int,
  469. device=self.device)
  470. query_lens_tensor = torch.tensor(query_lens,
  471. dtype=torch.long,
  472. device=self.device)
  473. query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
  474. dtype=torch.int32,
  475. device=self.device)
  476. seq_lens_tensor = torch.tensor(seq_lens,
  477. dtype=torch.int,
  478. device=self.device)
  479. seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
  480. dtype=torch.int32,
  481. device=self.device)
  482. torch.cumsum(query_lens_tensor,
  483. dim=0,
  484. dtype=query_start_loc.dtype,
  485. out=query_start_loc[1:])
  486. torch.cumsum(seq_lens_tensor,
  487. dim=0,
  488. dtype=seq_start_loc.dtype,
  489. out=seq_start_loc[1:])
  490. input_tokens_tensor = torch.tensor(input_tokens,
  491. dtype=torch.long,
  492. device=self.device)
  493. input_positions_tensor = torch.tensor(input_positions,
  494. dtype=torch.long,
  495. device=self.device)
  496. slot_mapping_tensor = torch.tensor(slot_mapping,
  497. dtype=torch.long,
  498. device=self.device)
  499. if self.attn_backend.get_name() == "flashinfer":
  500. if not hasattr(self, "flashinfer_workspace_buffer"):
  501. # Allocate 16MB workspace buffer
  502. # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
  503. self.flashinfer_workspace_buffer = torch.empty(
  504. 16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
  505. paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
  506. dtype=torch.int,
  507. device=self.device)
  508. paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
  509. dtype=torch.int,
  510. device=self.device)
  511. paged_kv_last_page_len_tensor = torch.tensor(
  512. paged_kv_last_page_len, dtype=torch.int, device=self.device)
  513. kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
  514. self.model_config.dtype)
  515. attn_metadata = self.attn_backend.make_metadata(
  516. num_prefills=num_prefills,
  517. slot_mapping=slot_mapping_tensor,
  518. num_prefill_tokens=num_prefill_tokens,
  519. num_decode_tokens=num_decode_tokens,
  520. use_cuda_graph=False,
  521. max_prefill_seq_len=max_prefill_seq_len,
  522. block_tables=block_tables,
  523. workspace_buffer=self.flashinfer_workspace_buffer,
  524. paged_kv_indptr=paged_kv_indptr_tensor,
  525. paged_kv_indices=paged_kv_indices_tensor,
  526. paged_kv_last_page_len=paged_kv_last_page_len_tensor,
  527. num_qo_heads=self.model_config.get_num_attention_heads(
  528. self.parallel_config),
  529. num_kv_heads=self.model_config.get_num_kv_heads(
  530. self.parallel_config),
  531. head_dim=self.model_config.get_head_size(),
  532. page_size=16,
  533. seq_start_loc=seq_start_loc,
  534. data_type=kv_cache_dtype)
  535. else:
  536. attn_metadata = self.attn_backend.make_metadata(
  537. num_prefills=num_prefills,
  538. slot_mapping=slot_mapping_tensor,
  539. num_prefill_tokens=num_prefill_tokens,
  540. num_decode_tokens=num_decode_tokens,
  541. seq_lens=seq_lens,
  542. seq_lens_tensor=seq_lens_tensor,
  543. max_query_len=max_query_len,
  544. max_prefill_seq_len=max_prefill_seq_len,
  545. max_decode_seq_len=max_decode_seq_len,
  546. query_start_loc=query_start_loc,
  547. seq_start_loc=seq_start_loc,
  548. context_lens_tensor=context_lens_tensor,
  549. block_tables=block_tables,
  550. use_cuda_graph=use_captured_graph,
  551. )
  552. if self.lora_config:
  553. lora_mapping = LoRAMapping(
  554. lora_index_mapping,
  555. lora_prompt_mapping,
  556. )
  557. else:
  558. lora_mapping = None
  559. return ModelInput(
  560. input_tokens=input_tokens_tensor,
  561. input_positions=input_positions_tensor,
  562. attn_metadata=attn_metadata,
  563. seq_lens=seq_lens,
  564. query_lens=query_lens,
  565. lora_mapping=lora_mapping,
  566. lora_requests=lora_requests,
  567. multi_modal_input=multi_modal_input,
  568. slot_mapping=slot_mapping_tensor,
  569. num_prefill_tokens=num_prefill_tokens,
  570. num_decode_tokens=num_decode_tokens,
  571. num_prefills=num_prefills,
  572. )
  573. def prepare_input_tensors(
  574. self,
  575. seq_group_metadata_list: List[SequenceGroupMetadata],
  576. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
  577. Set[LoRARequest], LoRAMapping, torch.Tensor]:
  578. if self.is_driver_worker:
  579. # Prepare input tensors.
  580. (
  581. input_tokens,
  582. input_positions,
  583. attn_metadata,
  584. seq_lens,
  585. query_lens,
  586. lora_mapping,
  587. lora_requests,
  588. multi_modal_input,
  589. slot_mapping,
  590. num_prefill_tokens,
  591. num_decode_tokens,
  592. num_prefills,
  593. ) = self._prepare_model_input(seq_group_metadata_list)
  594. sampling_metadata = SamplingMetadata.prepare(
  595. seq_group_metadata_list, seq_lens, query_lens, self.device,
  596. self.pin_memory)
  597. metadata_dict = {
  598. "input_tokens": input_tokens,
  599. "input_positions": input_positions,
  600. "selected_token_indices":
  601. sampling_metadata.selected_token_indices,
  602. "lora_requests": lora_requests,
  603. "lora_mapping": lora_mapping,
  604. "multi_modal_input": multi_modal_input,
  605. "num_prefill_tokens": num_prefill_tokens,
  606. "num_decode_tokens": num_decode_tokens,
  607. "slot_mapping": slot_mapping,
  608. "num_prefills": num_prefills,
  609. }
  610. if attn_metadata:
  611. metadata_dict.update(attn_metadata.asdict_zerocopy())
  612. broadcast_tensor_dict(metadata_dict, src=0)
  613. else:
  614. metadata_dict = broadcast_tensor_dict(src=0)
  615. input_tokens = metadata_dict.pop("input_tokens")
  616. input_positions = metadata_dict.pop("input_positions")
  617. selected_token_indices = metadata_dict.pop(
  618. "selected_token_indices")
  619. lora_mapping = metadata_dict.pop("lora_mapping")
  620. lora_requests = metadata_dict.pop("lora_requests")
  621. multi_modal_input = metadata_dict.pop("multi_modal_input")
  622. if metadata_dict:
  623. attn_metadata = self.attn_backend.make_metadata(
  624. **metadata_dict)
  625. else:
  626. attn_metadata = None
  627. sampling_metadata = SamplingMetadata(
  628. seq_groups=None,
  629. selected_token_indices=selected_token_indices,
  630. categorized_sample_indices=None,
  631. num_prompts=0,
  632. )
  633. return (input_tokens, input_positions, attn_metadata,
  634. sampling_metadata, lora_requests, lora_mapping,
  635. multi_modal_input)
  636. @torch.inference_mode()
  637. def execute_model(
  638. self,
  639. seq_group_metadata_list: List[SequenceGroupMetadata],
  640. kv_caches: List[torch.Tensor],
  641. ) -> Optional[SamplerOutput]:
  642. (input_tokens, input_positions, attn_metadata, sampling_metadata,
  643. lora_requests, lora_mapping, multi_modal_input
  644. ) = self.prepare_input_tensors(seq_group_metadata_list)
  645. if self.lora_config:
  646. self.set_active_loras(lora_requests, lora_mapping)
  647. # Currently cuda graph is only supported by the decode phase.
  648. prefill_meta = attn_metadata.prefill_metadata
  649. decode_meta = attn_metadata.decode_metadata
  650. if prefill_meta is None and decode_meta.use_cuda_graph:
  651. graph_batch_size = input_tokens.shape[0]
  652. model_executable = self.graph_runners[graph_batch_size]
  653. else:
  654. model_executable = self.model
  655. execute_model_kwargs = {
  656. "input_ids": input_tokens,
  657. "positions": input_positions,
  658. "kv_caches": kv_caches,
  659. "attn_metadata": attn_metadata,
  660. }
  661. if self.vision_language_config:
  662. execute_model_kwargs.update({"image_input": multi_modal_input})
  663. hidden_states = model_executable(**execute_model_kwargs)
  664. # Compute the logits.
  665. logits = self.model.compute_logits(hidden_states, sampling_metadata)
  666. # Only perform sampling in the driver worker.
  667. if not self.is_driver_worker:
  668. return None
  669. # Sample the next token.
  670. output = self.model.sample(
  671. logits=logits,
  672. sampling_metadata=sampling_metadata,
  673. )
  674. return output
  675. @torch.inference_mode()
  676. def profile_run(self) -> None:
  677. # Enable top-k sampling to reflect the accurate memory usage.
  678. sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
  679. max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
  680. max_num_seqs = self.scheduler_config.max_num_seqs
  681. # This represents the maximum number of different requests
  682. # that will have unique loras, an therefore the max amount of memory
  683. # consumption create dummy lora request copies from the lora request
  684. # passed in, which contains a lora from the lora warmup path.
  685. dummy_lora_requests = []
  686. dummy_lora_requests_per_seq = []
  687. if self.lora_config:
  688. assert self.lora_manager is not None
  689. with self.lora_manager.dummy_lora_cache():
  690. for idx in range(self.lora_config.max_loras):
  691. lora_id = idx + 1
  692. dummy_lora_request = LoRARequest(
  693. lora_name=f"warmup_{lora_id}",
  694. lora_int_id=lora_id,
  695. lora_local_path="/not/a/real/path",
  696. )
  697. self.lora_manager.add_dummy_lora(dummy_lora_request,
  698. rank=LORA_WARMUP_RANK)
  699. dummy_lora_requests.append(dummy_lora_request)
  700. dummy_lora_requests_per_seq = [
  701. dummy_lora_requests[idx % len(dummy_lora_requests)]
  702. for idx in range(max_num_seqs)
  703. ]
  704. # Profile memory usage with max_num_sequences sequences and the total
  705. # number of tokens equal to max_num_batched_tokens.
  706. seqs: List[SequenceGroupMetadata] = []
  707. # Additional GPU memory may be needed for vision encoding, which needs
  708. # to be accounted for when calculating the GPU blocks for
  709. # Aphrodite blocker manager.
  710. # To exercise the worst scenario for GPU memory consumption,
  711. # the number of seqs (batch_size) is chosen to maximize the number
  712. # of images processed.
  713. if self.vision_language_config:
  714. max_num_seqs = min(
  715. max_num_seqs,
  716. int(max_num_batched_tokens /
  717. self.vision_language_config.image_feature_size))
  718. for group_id in range(max_num_seqs):
  719. seq_len = (max_num_batched_tokens // max_num_seqs +
  720. (group_id < max_num_batched_tokens % max_num_seqs))
  721. seq_data, fake_multi_modal_input = _prepare_fake_inputs(
  722. seq_len, self.vision_language_config)
  723. seq = SequenceGroupMetadata(
  724. request_id=str(group_id),
  725. is_prompt=True,
  726. seq_data={group_id: seq_data},
  727. sampling_params=sampling_params,
  728. block_tables=None,
  729. lora_request=dummy_lora_requests_per_seq[group_id]
  730. if dummy_lora_requests_per_seq else None,
  731. multi_modal_data=fake_multi_modal_input,
  732. )
  733. seqs.append(seq)
  734. # Run the model with the dummy inputs.
  735. num_layers = self.model_config.get_num_layers(self.parallel_config)
  736. kv_caches = [None] * num_layers
  737. self.execute_model(seqs, kv_caches)
  738. torch.cuda.synchronize()
  739. return
  740. def remove_all_loras(self):
  741. if not self.lora_manager:
  742. raise RuntimeError("LoRA is not enabled.")
  743. self.lora_manager.remove_all_loras()
  744. def set_active_loras(self, lora_requests: Set[LoRARequest],
  745. lora_mapping: LoRAMapping) -> None:
  746. if not self.lora_manager:
  747. raise RuntimeError("LoRA is not enabled.")
  748. self.lora_manager.set_active_loras(lora_requests, lora_mapping)
  749. def add_lora(self, lora_request: LoRARequest) -> bool:
  750. if not self.lora_manager:
  751. raise RuntimeError("LoRA is not enabled.")
  752. return self.lora_manager.add_lora(lora_request)
  753. def remove_lora(self, lora_id: int) -> bool:
  754. if not self.lora_manager:
  755. raise RuntimeError("LoRA is not enabled.")
  756. return self.lora_manager.remove_lora(lora_id)
  757. def list_loras(self) -> Set[int]:
  758. if not self.lora_manager:
  759. raise RuntimeError("LoRA is not enabled.")
  760. return self.lora_manager.list_loras()
  761. @torch.inference_mode()
  762. def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
  763. """Cuda graph capture a model.
  764. Note that CUDA graph's performance gain is negligible if number
  765. of batched tokens are larger than 200. And since CUDA graph
  766. requires fixed sized tensors, supporting large/variable batch
  767. size requires high GPU memory overhead. Thus, vLLM only captures
  768. decoding requests. Mixed batch (chunked prefill + decoding) or
  769. prefill requests are not captured.
  770. Since it is used for decoding-only, it assumes there's only 1 token
  771. per sequence in the batch.
  772. """
  773. assert not self.model_config.enforce_eager
  774. logger.info("Capturing the model for CUDA graphs. This may lead to "
  775. "unexpected consequences if the model is not static. To "
  776. "run the model in eager mode, set 'enforce_eager=True' or "
  777. "use '--enforce-eager' in the CLI.")
  778. logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
  779. "If you are running out of memory, consider decreasing "
  780. "`gpu_memory_utilization` or enforcing eager mode. "
  781. "You can also reduce the `max_num_seqs` as needed "
  782. "to decrease memory usage.")
  783. start_time = time.perf_counter()
  784. # Prepare dummy inputs. These will be reused for all batch sizes.
  785. max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
  786. input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
  787. input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
  788. slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
  789. slot_mapping.fill_(_PAD_SLOT_ID)
  790. seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
  791. block_tables = torch.from_numpy(self.graph_block_tables).cuda()
  792. graph_batch_size = _get_graph_batch_size(
  793. self.scheduler_config.max_num_seqs)
  794. batch_size_capture_list = [
  795. bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
  796. ]
  797. with graph_capture() as graph_capture_context:
  798. # NOTE: Capturing the largest batch size first may help reduce the
  799. # memory usage of CUDA graph.
  800. for batch_size in reversed(batch_size_capture_list):
  801. # Create dummy attn_metadata.
  802. attn_metadata = self.attn_backend.make_metadata(
  803. num_prefills=0,
  804. num_prefill_tokens=0,
  805. num_decode_tokens=batch_size,
  806. slot_mapping=slot_mapping[:batch_size],
  807. seq_lens=None,
  808. seq_lens_tensor=seq_lens[:batch_size],
  809. max_query_len=None,
  810. max_prefill_seq_len=0,
  811. max_decode_seq_len=self.max_seq_len_to_capture,
  812. query_start_loc=None,
  813. seq_start_loc=None,
  814. context_lens_tensor=None,
  815. block_tables=block_tables[:batch_size],
  816. use_cuda_graph=True,
  817. )
  818. if self.lora_config:
  819. lora_mapping = LoRAMapping(
  820. [0] * batch_size,
  821. [0] * batch_size,
  822. )
  823. self.set_active_loras(set(), lora_mapping)
  824. graph_runner = CUDAGraphRunner(self.model)
  825. graph_runner.capture(
  826. input_tokens[:batch_size],
  827. input_positions[:batch_size],
  828. kv_caches,
  829. attn_metadata,
  830. memory_pool=self.graph_memory_pool,
  831. stream=graph_capture_context.stream,
  832. )
  833. self.graph_memory_pool = graph_runner.graph.pool()
  834. self.graph_runners[batch_size] = graph_runner
  835. end_time = time.perf_counter()
  836. elapsed_time = end_time - start_time
  837. # This usually takes < 10 seconds.
  838. logger.info(f"Graph capturing finished in {elapsed_time} secs.")
  839. def __del__(self) -> None:
  840. # Delete the CUDA graphs before deleting the pynccl communicator.
  841. # NOTE: This is necessary because otherwise deadlocks can
  842. # happen.
  843. # FIXME: This is a bit hacky. Find a more robust solution.
  844. # TODO: when we get enough user feedback that pynccl is
  845. # more stable than cupy, we can remove this, e.g. in v0.4.1.
  846. self.graph_runners.clear()
  847. self.pynccl_backend = None
  848. @property
  849. def vocab_size(self) -> int:
  850. return self.model_config.get_vocab_size()
  851. class CUDAGraphRunner:
  852. def __init__(self, model: nn.Module):
  853. self.model = model
  854. self.input_buffers: Dict[str, torch.Tensor] = {}
  855. self.output_buffers: Dict[str, torch.Tensor] = {}
  856. self._graph: Optional[torch.cuda.CUDAGraph] = None
  857. @property
  858. def graph(self):
  859. assert self._graph is not None
  860. return self._graph
  861. def capture(
  862. self,
  863. input_ids: torch.Tensor,
  864. positions: torch.Tensor,
  865. kv_caches: List[torch.Tensor],
  866. attn_metadata: AttentionMetadata,
  867. memory_pool: Optional[Tuple[int, int]],
  868. stream: torch.cuda.Stream,
  869. **kwargs,
  870. ) -> None:
  871. assert self._graph is None
  872. # Run the model once without capturing the graph.
  873. # This is to make sure that the captured graph does not include the
  874. # kernel launches for initial benchmarking (e.g., Triton autotune).
  875. self.model(
  876. input_ids,
  877. positions,
  878. kv_caches,
  879. attn_metadata,
  880. **kwargs,
  881. )
  882. torch.cuda.synchronize()
  883. # Capture the graph.
  884. self._graph = torch.cuda.CUDAGraph()
  885. with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
  886. hidden_states = self.model(
  887. input_ids,
  888. positions,
  889. kv_caches,
  890. attn_metadata,
  891. **kwargs,
  892. )
  893. torch.cuda.synchronize()
  894. # Save the input and output buffers.
  895. self.input_buffers = {
  896. "input_ids": input_ids,
  897. "positions": positions,
  898. "kv_caches": kv_caches,
  899. "slot_mapping": attn_metadata.slot_mapping,
  900. "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
  901. "block_tables": attn_metadata.decode_metadata.block_tables,
  902. }
  903. self.output_buffers = {"hidden_states": hidden_states}
  904. return
  905. def forward(
  906. self,
  907. input_ids: torch.Tensor,
  908. positions: torch.Tensor,
  909. kv_caches: List[torch.Tensor],
  910. attn_metadata: AttentionMetadata,
  911. **kwargs,
  912. ) -> torch.Tensor:
  913. # KV caches are fixed tensors, so we don't need to copy them.
  914. del kv_caches
  915. # Copy the input tensors to the input buffers.
  916. self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
  917. self.input_buffers["positions"].copy_(positions, non_blocking=True)
  918. self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
  919. non_blocking=True)
  920. self.input_buffers["seq_lens_tensor"].copy_(
  921. attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
  922. self.input_buffers["block_tables"].copy_(
  923. attn_metadata.decode_metadata.block_tables, non_blocking=True)
  924. # Run the graph.
  925. self.graph.replay()
  926. # Return the output tensor.
  927. return self.output_buffers["hidden_states"]
  928. def __call__(self, *args, **kwargs):
  929. return self.forward(*args, **kwargs)
  930. def _get_graph_batch_size(batch_size: int) -> int:
  931. """Returns the padded batch size given actual batch size.
  932. Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
  933. 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
  934. """
  935. if batch_size <= 2:
  936. return batch_size
  937. elif batch_size <= 4:
  938. return 4
  939. else:
  940. return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
  941. _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
  942. def _prepare_fake_inputs(
  943. seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
  944. """Prepare fake inputs for profile run."""
  945. if vision_language_config:
  946. prompt_tokens = [
  947. vision_language_config.image_token_id
  948. ] * vision_language_config.image_feature_size + [0] * (
  949. seq_len - vision_language_config.image_feature_size)
  950. fake_image_input = MultiModalData(
  951. type=MultiModalData.Type.IMAGE,
  952. data=torch.zeros(vision_language_config.image_input_shape,
  953. dtype=torch.float16))
  954. else:
  955. prompt_tokens = [0] * seq_len
  956. fake_image_input = None
  957. return SequenceData(prompt_tokens), fake_image_input
  958. def _is_block_tables_empty(block_tables: Union[None, Dict]):
  959. """
  960. Check if block_tables is None or a dictionary with all None values.
  961. """
  962. if block_tables is None:
  963. return True
  964. if isinstance(block_tables, dict) and all(
  965. value is None for value in block_tables.values()):
  966. return True
  967. return False