model_runner.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083
  1. import time
  2. import warnings
  3. from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
  4. import numpy as np
  5. import torch
  6. import torch.nn as nn
  7. from loguru import logger
  8. from aphrodite.attention import AttentionMetadata, get_attn_backend
  9. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  10. LoRAConfig, ModelConfig, ParallelConfig,
  11. SchedulerConfig, VisionLanguageConfig)
  12. from aphrodite.common.sampling_params import SamplingParams
  13. from aphrodite.common.sequence import (MultiModalData, SamplerOutput,
  14. SequenceData, SequenceGroupMetadata)
  15. from aphrodite.common.utils import (CudaMemoryProfiler,
  16. get_kv_cache_torch_dtype, is_hip,
  17. is_pin_memory_available,
  18. make_tensor_with_pad)
  19. from aphrodite.distributed import broadcast_tensor_dict
  20. from aphrodite.distributed.communication_op import graph_capture
  21. from aphrodite.distributed.parallel_state import (
  22. get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
  23. from aphrodite.lora.layers import LoRAMapping
  24. from aphrodite.lora.request import LoRARequest
  25. from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
  26. from aphrodite.modeling import SamplingMetadata
  27. from aphrodite.modeling.model_loader import get_model
  28. _PAD_SLOT_ID = -1
  29. LORA_WARMUP_RANK = 8
  30. _BATCH_SIZE_ALIGNMENT = 8
  31. # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
  32. # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
  33. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
  34. _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
  35. ]
  36. class ModelInput(NamedTuple):
  37. input_tokens: torch.Tensor
  38. input_positions: torch.Tensor
  39. attn_metadata: Optional[AttentionMetadata]
  40. seq_lens: List[int]
  41. query_lens: List[int]
  42. lora_mapping: Optional[LoRAMapping]
  43. lora_requests: Set[LoRARequest]
  44. multi_modal_input: Optional[torch.Tensor]
  45. slot_mapping: torch.Tensor
  46. num_prefill_tokens: int
  47. num_decode_tokens: int
  48. num_prefills: int
  49. @classmethod
  50. def empty(cls, device):
  51. return ModelInput(
  52. input_tokens=torch.empty(0, device=device),
  53. input_positions=torch.empty(0, device=device),
  54. attn_metadata=None,
  55. seq_lens=[],
  56. query_lens=[],
  57. lora_mapping=None,
  58. lora_requests=set(),
  59. multi_modal_input=None,
  60. slot_mapping=torch.empty(0, device=device),
  61. num_prefill_tokens=0,
  62. num_decode_tokens=0,
  63. num_prefills=0,
  64. )
  65. class ModelRunner:
  66. def __init__(
  67. self,
  68. model_config: ModelConfig,
  69. parallel_config: ParallelConfig,
  70. scheduler_config: SchedulerConfig,
  71. device_config: DeviceConfig,
  72. cache_config: CacheConfig,
  73. load_config: LoadConfig,
  74. lora_config: Optional[LoRAConfig],
  75. kv_cache_dtype: Optional[str] = "auto",
  76. is_driver_worker: bool = False,
  77. vision_language_config: Optional[VisionLanguageConfig] = None,
  78. ):
  79. self.model_config = model_config
  80. self.parallel_config = parallel_config
  81. self.scheduler_config = scheduler_config
  82. self.device_config = device_config
  83. self.cache_config = cache_config
  84. self.lora_config = lora_config
  85. self.load_config = load_config
  86. self.is_driver_worker = is_driver_worker
  87. self.vision_language_config = vision_language_config
  88. self.device = self.device_config.device
  89. self.pin_memory = is_pin_memory_available()
  90. self.kv_cache_dtype = kv_cache_dtype
  91. self.sliding_window = model_config.get_sliding_window()
  92. self.block_size = cache_config.block_size
  93. self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
  94. self.graph_runners: Dict[int, CUDAGraphRunner] = {}
  95. self.graph_memory_pool: Optional[Tuple[
  96. int, int]] = None # Set during graph capture.
  97. # When using CUDA graph, the input block tables must be padded to
  98. # max_seq_len_to_capture. However, creating the block table in
  99. # Python can be expensive. To optimize this, we cache the block table
  100. # in numpy and only copy the actual input content at every iteration.
  101. # The shape of the cached block table will be
  102. # (max batch size to capture, max context len to capture / block size).
  103. self.graph_block_tables = np.zeros(
  104. (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
  105. dtype=np.int32)
  106. self.attn_backend = get_attn_backend(
  107. self.model_config.get_num_attention_heads(self.parallel_config),
  108. self.model_config.get_head_size(),
  109. self.model_config.get_num_kv_heads(self.parallel_config),
  110. self.model_config.get_sliding_window(),
  111. self.model_config.dtype,
  112. self.kv_cache_dtype,
  113. self.block_size,
  114. )
  115. # Lazy initialization
  116. self.model: nn.Module # Set after load_model
  117. # Set if the backend is flashinfer.
  118. self.flashinfer_workspace_buffer: torch.Tensor
  119. # Set after load_model.
  120. self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
  121. def load_model(self) -> None:
  122. with CudaMemoryProfiler() as m:
  123. # measure the time it takes to load the model
  124. start_time = time.time()
  125. self.model = get_model(
  126. model_config=self.model_config,
  127. device_config=self.device_config,
  128. load_config=self.load_config,
  129. lora_config=self.lora_config,
  130. vision_language_config=self.vision_language_config,
  131. parallel_config=self.parallel_config,
  132. scheduler_config=self.scheduler_config,
  133. cache_config=self.cache_config,
  134. )
  135. end_time = time.time()
  136. self.model_memory_usage = m.consumed_memory
  137. tp = get_tensor_model_parallel_world_size()
  138. rank = get_tensor_model_parallel_rank()
  139. total_time = end_time - start_time
  140. if tp > 1:
  141. logger.info(
  142. f"Rank {rank}: Model weights loaded in {total_time:.2f} secs.")
  143. if rank == 0:
  144. logger.info(
  145. "Memory usage: "
  146. f"{self.model_memory_usage / float(2**30):.2f} GiB x {tp} ="
  147. f" {self.model_memory_usage * tp / float(2**30):.2f} GiB")
  148. else:
  149. logger.info(f"Model weights loaded in {total_time:.2f} seconds.")
  150. logger.info("Memory usage: "
  151. f"{self.model_memory_usage / float(2**30):.2f} GiB")
  152. if self.lora_config:
  153. assert hasattr(self.model, "supported_lora_modules"
  154. ) and self.model.supported_lora_modules, (
  155. "Model does not support LoRA")
  156. assert hasattr(
  157. self.model,
  158. "embedding_modules"), "Model does not have embedding_modules"
  159. assert hasattr(self.model, "embedding_padding_modules"
  160. ), "Model does not have embedding_padding_modules"
  161. self.lora_manager = LRUCacheWorkerLoRAManager(
  162. self.scheduler_config.max_num_seqs,
  163. self.scheduler_config.max_num_batched_tokens,
  164. self.vocab_size,
  165. self.lora_config,
  166. self.device,
  167. self.model.embedding_modules,
  168. self.model.embedding_padding_modules,
  169. max_position_embeddings=self.model.config.
  170. max_position_embeddings,
  171. )
  172. self.model = self.lora_manager.create_lora_manager(self.model)
  173. if self.kv_cache_dtype == "fp8" and is_hip():
  174. # Currently only ROCm accepts kv-cache scaling factors
  175. # via quantization_param_path and this will be deprecated
  176. # in the future.
  177. if self.model_config.quantization_param_path is not None:
  178. if callable(getattr(self.model, "load_kv_cache_scales", None)):
  179. warnings.warn(
  180. "Loading kv cache scaling factor from JSON is "
  181. "deprecated and will be removed. Please include "
  182. "kv cache scaling factors in the model checkpoint.",
  183. FutureWarning,
  184. stacklevel=2)
  185. self.model.load_kv_cache_scales(
  186. self.model_config.quantization_param_path)
  187. logger.info(
  188. "Loaded KV cache scaling factors from ",
  189. f"{self.model_config.quantization_param_path}")
  190. else:
  191. raise RuntimeError(
  192. "Using FP8 KV cache and scaling factors provided but "
  193. f"model {self.model.__class__} does not support loading"
  194. " scaling factors.", )
  195. else:
  196. logger.warning(
  197. "Using FP8 KV cache but no scaling factors "
  198. "provided. Defaulting to scaling factors of 1.0. "
  199. "This may lead to less accurate results!")
  200. def save_sharded_state(
  201. self,
  202. path: str,
  203. pattern: Optional[str] = None,
  204. max_size: Optional[int] = None,
  205. ) -> None:
  206. from aphrodite.modeling.model_loader.loader import ShardedStateLoader
  207. ShardedStateLoader.save_model(
  208. self.model,
  209. path,
  210. pattern=pattern,
  211. max_size=max_size,
  212. )
  213. def get_max_block_per_batch(self) -> int:
  214. block_size = self.block_size
  215. return (self.max_seq_len_to_capture + block_size - 1) // block_size
  216. def _prepare_model_input(
  217. self,
  218. seq_group_metadata_list: List[SequenceGroupMetadata],
  219. ) -> ModelInput:
  220. """Prepare the model input based on a given sequence group.
  221. The API assumes seq_group_metadata_list is sorted by prefill -> decode.
  222. The result tensors and data structure also batches input in prefill
  223. -> decode order. For example,
  224. - input_tokens[:num_prefill_tokens] contains prefill tokens.
  225. - input_tokens[num_prefill_tokens:] contains decode tokens.
  226. If cuda graph is required, this API automatically pads inputs.
  227. """
  228. input_tokens: List[int] = []
  229. input_positions: List[int] = []
  230. slot_mapping: List[int] = []
  231. lora_index_mapping: List[int] = []
  232. lora_prompt_mapping: List[int] = []
  233. lora_requests: Set[LoRARequest] = set()
  234. seq_lens: List[int] = []
  235. prefill_seq_lens: List[int] = []
  236. decode_seq_lens: List[int] = []
  237. context_lens: List[int] = []
  238. query_lens: List[int] = []
  239. block_tables: List[List[int]] = []
  240. multi_modal_input_list: List[torch.Tensor] = []
  241. decode_only = True
  242. num_prefills = 0
  243. num_prefill_tokens = 0
  244. num_decode_tokens = 0
  245. # The following fields are only for flashinfer
  246. # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
  247. # for the precise definition of the following fields.
  248. # An example:
  249. # request 1, page indices [0, 5, 8]
  250. # request 2, page indices [1, 6, 7]
  251. # request 3, page indices [3, 4]
  252. # paged_kv_indices is a concatenation of page indices of all requests:
  253. # [0, 5, 8, 1, 6, 7, 3, 4]
  254. # paged_kv_indptr is used to index into paged_kv_indices:
  255. # [0, 3, 6, 8]
  256. paged_kv_indices: List[int] = []
  257. # 0 at the beginning of paged_kv_indptr indicates the start of the
  258. # first request’s page indices in the paged_kv_indices list.
  259. paged_kv_indptr: List[int] = [0]
  260. # paged_kv_last_page_len is the length of the last page of each request
  261. paged_kv_last_page_len: List[int] = []
  262. if len(seq_group_metadata_list) == 0:
  263. return ModelInput.empty(self.device)
  264. if self.sliding_window is not None:
  265. sliding_window_blocks = (self.sliding_window + self.block_size -
  266. 1) // self.block_size
  267. block_aligned_sliding_window = \
  268. sliding_window_blocks * self.block_size
  269. for seq_group_metadata in seq_group_metadata_list:
  270. seq_ids = list(seq_group_metadata.seq_data.keys())
  271. is_prompt = seq_group_metadata.is_prompt
  272. for seq_id in seq_ids:
  273. computed_block_nums = seq_group_metadata.computed_block_nums
  274. if (self.scheduler_config is not None
  275. and self.scheduler_config.chunked_prefill_enabled
  276. and not (computed_block_nums is None
  277. or computed_block_nums == [])):
  278. raise RuntimeError(
  279. "chunked prefill cannot be used with prefix caching "
  280. "now.")
  281. seq_data = seq_group_metadata.seq_data[seq_id]
  282. if is_prompt:
  283. context_len = seq_data.get_num_computed_tokens()
  284. else:
  285. # get_num_computed_tokens is incorrect for spec decoding.
  286. # So, we should have a special logic here.
  287. # TODO: Fix it.
  288. context_len = seq_data.get_len() - 1
  289. seq_len = min(
  290. seq_data.get_len(),
  291. context_len + seq_group_metadata.token_chunk_size)
  292. if is_prompt:
  293. tokens = seq_data.get_token_ids()[context_len:seq_len]
  294. else:
  295. # Optimization. get_token_ids requires the entire copy of
  296. # tokens.
  297. tokens = [seq_data.get_last_token_id()]
  298. # Prefix cache was hit.
  299. # Prefix is not supported with sliding_window
  300. prefix_cache_hit = (computed_block_nums is not None
  301. and len(computed_block_nums) > 0
  302. and self.sliding_window is None
  303. and is_prompt)
  304. # These are seq_len/context_len capped to the sliding window.
  305. # They are passed to decode kernel.
  306. # We still need original seq_len/context_len to compute slot
  307. # mapping (and input position) below.
  308. curr_sliding_window_blocks = None
  309. sliding_seq_len = seq_len
  310. sliding_context_len = context_len
  311. # TODO: This is a hack to make sliding window work with
  312. # paged attn. We can remove it if we make paged attn kernel
  313. # to properly handle slinding window attn.
  314. if (self.sliding_window is not None and not is_prompt):
  315. curr_sliding_window_blocks = sliding_window_blocks
  316. if self.scheduler_config.use_v2_block_manager:
  317. # number of elements in last block
  318. suff_len = seq_len % self.block_size
  319. sliding_seq_len = min(
  320. seq_len, block_aligned_sliding_window + suff_len)
  321. if suff_len > 0:
  322. curr_sliding_window_blocks += 1
  323. else:
  324. sliding_seq_len = min(seq_len, self.sliding_window)
  325. sliding_context_len = sliding_seq_len - 1
  326. # TODO: Combine chunked prefill and prefix caching by
  327. # only allowing multiple of block_size chunk size.
  328. # NOTE: This only works for oooooooxxx style attention.
  329. if prefix_cache_hit:
  330. assert computed_block_nums is not None
  331. context_len = len(computed_block_nums) * self.block_size
  332. tokens = tokens[context_len:]
  333. # need to think what to set it to when we have both sliding
  334. # window and prefix caching...
  335. assert self.sliding_window is None, \
  336. "Prefix caching is not supported with sliding window"
  337. sliding_context_len = context_len
  338. if self.attn_backend.get_name() == "flash-attn":
  339. # NOTE: For flash-attn, the block table should
  340. # include the entries for the incoming prefill tokens.
  341. # TODO: This is a temporary fix. We should
  342. # provide a unified interface for different backends.
  343. block_table = seq_group_metadata.block_tables[seq_id]
  344. else:
  345. block_table = computed_block_nums
  346. elif (self.scheduler_config.chunked_prefill_enabled
  347. or not is_prompt):
  348. if seq_group_metadata.block_tables is not None:
  349. # chunked prefill or decode
  350. block_table = seq_group_metadata.block_tables[seq_id]
  351. if curr_sliding_window_blocks is not None:
  352. block_table = block_table[
  353. -curr_sliding_window_blocks:]
  354. if self.attn_backend.get_name() == "flashinfer":
  355. paged_kv_indices.extend(block_table)
  356. paged_kv_indptr.append(paged_kv_indptr[-1] +
  357. len(block_table))
  358. last_page_len = seq_data.get_len(
  359. ) % self.block_size
  360. if last_page_len == 0:
  361. last_page_len = self.block_size
  362. paged_kv_last_page_len.append(last_page_len)
  363. else:
  364. # Only happens when memory profiling runs.
  365. block_table = []
  366. else:
  367. # Prefill without chunked prefill or memory profiling.
  368. block_table = []
  369. block_tables.append(block_table)
  370. seq_lens.append(sliding_seq_len)
  371. context_lens.append(sliding_context_len)
  372. query_len = sliding_seq_len - sliding_context_len
  373. query_lens.append(query_len)
  374. input_tokens.extend(tokens)
  375. input_positions.extend(list(range(context_len, seq_len)))
  376. lora_id = seq_group_metadata.lora_int_id
  377. if is_prompt:
  378. assert len(seq_ids) == 1
  379. num_prefills += 1
  380. num_prefill_tokens += len(tokens)
  381. decode_only = False
  382. prefill_seq_lens.append(seq_len)
  383. else:
  384. assert query_len == 1, (
  385. "seq_len: {}, context_len: {}, query_len: {}".format(
  386. seq_len, context_len, query_len))
  387. num_decode_tokens += query_len
  388. decode_seq_lens.append(sliding_seq_len)
  389. if lora_id > 0:
  390. lora_requests.add(seq_group_metadata.lora_request)
  391. lora_index_mapping += [lora_id] * query_len
  392. lora_prompt_mapping.extend(
  393. [lora_id] *
  394. (query_len if seq_group_metadata.sampling_params
  395. and seq_group_metadata.sampling_params.prompt_logprobs
  396. else 1))
  397. if seq_group_metadata.multi_modal_data:
  398. multi_modal_input_list.append(
  399. seq_group_metadata.multi_modal_data.data)
  400. if _is_block_tables_empty(seq_group_metadata.block_tables):
  401. # During memory profiling, the block tables are not
  402. # initialized yet. In this case, we just use a dummy
  403. # slot mapping.
  404. # In embeddings, the block tables are {seq_id: None}.
  405. slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
  406. continue
  407. # Compute the slot mapping.
  408. block_table = seq_group_metadata.block_tables[seq_id]
  409. # Mask the [0, start_idx) tokens of the prompt with
  410. # _PAD_SLOT_ID, where start_idx is max(0, seq_len -
  411. # sliding_window). For example, if the prompt len is 10,
  412. # sliding window is 8, and block size is 4, the first two
  413. # tokens are masked and the slot mapping will be
  414. # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
  415. start_idx = 0
  416. if self.sliding_window is not None:
  417. if is_prompt:
  418. assert self.scheduler_config.use_v2_block_manager \
  419. or context_len == 0, (
  420. "Prefix caching is currently not supported with "
  421. "sliding window attention in V1 block manager")
  422. # It is an optimization. When it is decoding, it is always
  423. # 0. When prefill, we use it to not write slots to kv cache
  424. # to save memory.
  425. start_idx = max(0, query_len - self.sliding_window)
  426. for i in range(context_len, seq_len):
  427. if i < start_idx:
  428. slot_mapping.append(_PAD_SLOT_ID)
  429. continue
  430. block_number = block_table[i // self.block_size]
  431. block_offset = i % self.block_size
  432. slot = block_number * self.block_size + block_offset
  433. slot_mapping.append(slot)
  434. batch_size = len(input_tokens)
  435. max_query_len = max(query_lens)
  436. max_prefill_seq_len = max(prefill_seq_lens, default=0)
  437. max_decode_seq_len = max(decode_seq_lens, default=0)
  438. # If cuda graph can be used, pad tensors accordingly.
  439. # See `capture_model` API for more details.
  440. # Aphrodite uses cuda graph only for decoding requests.
  441. use_captured_graph = (
  442. decode_only and not self.model_config.enforce_eager
  443. and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
  444. and max_decode_seq_len <= self.max_seq_len_to_capture)
  445. if use_captured_graph:
  446. graph_batch_size = _get_graph_batch_size(batch_size)
  447. assert graph_batch_size >= batch_size
  448. for _ in range(graph_batch_size - batch_size):
  449. input_tokens.append(0)
  450. input_positions.append(0)
  451. slot_mapping.append(_PAD_SLOT_ID)
  452. seq_lens.append(1)
  453. block_tables.append([])
  454. lora_index_mapping.append(0)
  455. batch_size = graph_batch_size
  456. num_decode_tokens = batch_size
  457. if use_captured_graph:
  458. # The shape of graph_block_tables is
  459. # [max batch size, max context len // block size].
  460. input_block_tables = self.graph_block_tables[:batch_size]
  461. for i, block_table in enumerate(block_tables):
  462. if block_table:
  463. input_block_tables[i, :len(block_table)] = block_table
  464. block_tables = torch.tensor(input_block_tables, device=self.device)
  465. else:
  466. max_block_table_len = max(
  467. len(block_table) for block_table in block_tables)
  468. block_tables = make_tensor_with_pad(
  469. block_tables,
  470. max_len=max_block_table_len,
  471. pad=0,
  472. dtype=torch.int,
  473. device=self.device,
  474. )
  475. assert max_query_len > 0, ("query_lens: {}".format(query_lens))
  476. context_lens_tensor = torch.tensor(context_lens,
  477. dtype=torch.int,
  478. device=self.device)
  479. if multi_modal_input_list:
  480. assert self.vision_language_config, (
  481. "Multi-modal inputs are only supported by "
  482. "vision language models.")
  483. multi_modal_input = torch.cat(multi_modal_input_list,
  484. dim=0).to(self.device)
  485. else:
  486. multi_modal_input = None
  487. query_lens_tensor = torch.tensor(query_lens,
  488. dtype=torch.long,
  489. device=self.device)
  490. query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
  491. dtype=torch.int32,
  492. device=self.device)
  493. seq_lens_tensor = torch.tensor(seq_lens,
  494. dtype=torch.int,
  495. device=self.device)
  496. seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
  497. dtype=torch.int32,
  498. device=self.device)
  499. torch.cumsum(query_lens_tensor,
  500. dim=0,
  501. dtype=query_start_loc.dtype,
  502. out=query_start_loc[1:])
  503. torch.cumsum(seq_lens_tensor,
  504. dim=0,
  505. dtype=seq_start_loc.dtype,
  506. out=seq_start_loc[1:])
  507. input_tokens_tensor = torch.tensor(input_tokens,
  508. dtype=torch.long,
  509. device=self.device)
  510. input_positions_tensor = torch.tensor(input_positions,
  511. dtype=torch.long,
  512. device=self.device)
  513. slot_mapping_tensor = torch.tensor(slot_mapping,
  514. dtype=torch.long,
  515. device=self.device)
  516. if self.attn_backend.get_name() == "flashinfer":
  517. if not hasattr(self, "flashinfer_workspace_buffer"):
  518. # Allocate 16MB workspace buffer
  519. # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
  520. self.flashinfer_workspace_buffer = torch.empty(
  521. 16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
  522. paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
  523. dtype=torch.int,
  524. device=self.device)
  525. paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
  526. dtype=torch.int,
  527. device=self.device)
  528. paged_kv_last_page_len_tensor = torch.tensor(
  529. paged_kv_last_page_len, dtype=torch.int, device=self.device)
  530. kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
  531. self.model_config.dtype)
  532. attn_metadata = self.attn_backend.make_metadata(
  533. num_prefills=num_prefills,
  534. slot_mapping=slot_mapping_tensor,
  535. num_prefill_tokens=num_prefill_tokens,
  536. num_decode_tokens=num_decode_tokens,
  537. use_cuda_graph=False,
  538. max_prefill_seq_len=max_prefill_seq_len,
  539. block_tables=block_tables,
  540. workspace_buffer=self.flashinfer_workspace_buffer,
  541. paged_kv_indptr=paged_kv_indptr_tensor,
  542. paged_kv_indices=paged_kv_indices_tensor,
  543. paged_kv_last_page_len=paged_kv_last_page_len_tensor,
  544. num_qo_heads=self.model_config.get_num_attention_heads(
  545. self.parallel_config),
  546. num_kv_heads=self.model_config.get_num_kv_heads(
  547. self.parallel_config),
  548. head_dim=self.model_config.get_head_size(),
  549. page_size=16,
  550. seq_start_loc=seq_start_loc,
  551. data_type=kv_cache_dtype)
  552. else:
  553. attn_metadata = self.attn_backend.make_metadata(
  554. num_prefills=num_prefills,
  555. slot_mapping=slot_mapping_tensor,
  556. num_prefill_tokens=num_prefill_tokens,
  557. num_decode_tokens=num_decode_tokens,
  558. seq_lens=seq_lens,
  559. seq_lens_tensor=seq_lens_tensor,
  560. max_query_len=max_query_len,
  561. max_prefill_seq_len=max_prefill_seq_len,
  562. max_decode_seq_len=max_decode_seq_len,
  563. query_start_loc=query_start_loc,
  564. seq_start_loc=seq_start_loc,
  565. context_lens_tensor=context_lens_tensor,
  566. block_tables=block_tables,
  567. use_cuda_graph=use_captured_graph,
  568. )
  569. if self.lora_config:
  570. lora_mapping = LoRAMapping(
  571. lora_index_mapping,
  572. lora_prompt_mapping,
  573. )
  574. else:
  575. lora_mapping = None
  576. return ModelInput(
  577. input_tokens=input_tokens_tensor,
  578. input_positions=input_positions_tensor,
  579. attn_metadata=attn_metadata,
  580. seq_lens=seq_lens,
  581. query_lens=query_lens,
  582. lora_mapping=lora_mapping,
  583. lora_requests=lora_requests,
  584. multi_modal_input=multi_modal_input,
  585. slot_mapping=slot_mapping_tensor,
  586. num_prefill_tokens=num_prefill_tokens,
  587. num_decode_tokens=num_decode_tokens,
  588. num_prefills=num_prefills,
  589. )
  590. def prepare_input_tensors(
  591. self,
  592. seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
  593. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
  594. Set[LoRARequest], LoRAMapping, torch.Tensor]:
  595. if self.is_driver_worker:
  596. assert seq_group_metadata_list is not None
  597. # Prepare input tensors.
  598. (
  599. input_tokens,
  600. input_positions,
  601. attn_metadata,
  602. seq_lens,
  603. query_lens,
  604. lora_mapping,
  605. lora_requests,
  606. multi_modal_input,
  607. slot_mapping,
  608. num_prefill_tokens,
  609. num_decode_tokens,
  610. num_prefills,
  611. ) = self._prepare_model_input(seq_group_metadata_list)
  612. sampling_metadata = SamplingMetadata.prepare(
  613. seq_group_metadata_list, seq_lens, query_lens, self.device,
  614. self.pin_memory)
  615. metadata_dict = {
  616. "input_tokens": input_tokens,
  617. "input_positions": input_positions,
  618. "selected_token_indices":
  619. sampling_metadata.selected_token_indices,
  620. "lora_requests": lora_requests,
  621. "lora_mapping": lora_mapping,
  622. "multi_modal_input": multi_modal_input,
  623. "num_prefill_tokens": num_prefill_tokens,
  624. "num_decode_tokens": num_decode_tokens,
  625. "slot_mapping": slot_mapping,
  626. "num_prefills": num_prefills,
  627. }
  628. if attn_metadata:
  629. metadata_dict.update(attn_metadata.asdict_zerocopy())
  630. broadcast_tensor_dict(metadata_dict, src=0)
  631. else:
  632. metadata_dict = broadcast_tensor_dict(src=0)
  633. input_tokens = metadata_dict.pop("input_tokens")
  634. input_positions = metadata_dict.pop("input_positions")
  635. selected_token_indices = metadata_dict.pop(
  636. "selected_token_indices")
  637. lora_mapping = metadata_dict.pop("lora_mapping")
  638. lora_requests = metadata_dict.pop("lora_requests")
  639. multi_modal_input = metadata_dict.pop("multi_modal_input")
  640. if metadata_dict:
  641. attn_metadata = self.attn_backend.make_metadata(
  642. **metadata_dict)
  643. else:
  644. attn_metadata = None
  645. sampling_metadata = SamplingMetadata(
  646. seq_groups=None,
  647. selected_token_indices=selected_token_indices,
  648. categorized_sample_indices=None,
  649. num_prompts=0,
  650. )
  651. return (input_tokens, input_positions, attn_metadata,
  652. sampling_metadata, lora_requests, lora_mapping,
  653. multi_modal_input)
  @torch.inference_mode()
  def execute_model(
      self,
      seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
      kv_caches: List[torch.Tensor],
  ) -> Optional[SamplerOutput]:
      """Run one forward step over a batch and sample on the driver worker.

      Args:
          seq_group_metadata_list: Scheduled sequence groups, handed as-is to
              ``prepare_input_tensors``.
          kv_caches: Per-layer KV cache tensors passed to the model.

      Returns:
          The result of ``self.model.sample`` on the driver worker; ``None``
          on every other worker.
      """
      (token_ids, token_positions, attn_metadata, sampling_metadata,
       lora_requests, lora_mapping,
       multi_modal_input) = self.prepare_input_tensors(seq_group_metadata_list)

      if self.lora_config:
          self.set_active_loras(lora_requests, lora_mapping)

      # Currently cuda graph is only supported by the decode phase: a
      # captured graph runner (keyed by the number of input tokens) is used
      # only when the batch has no prefill part and is marked for graphs.
      prefill_meta = attn_metadata.prefill_metadata
      decode_meta = attn_metadata.decode_metadata
      replay_graph = prefill_meta is None and decode_meta.use_cuda_graph
      model_executable = (self.graph_runners[token_ids.shape[0]]
                          if replay_graph else self.model)

      forward_kwargs = dict(
          input_ids=token_ids,
          positions=token_positions,
          kv_caches=kv_caches,
          attn_metadata=attn_metadata,
      )
      if self.vision_language_config:
          forward_kwargs["image_input"] = multi_modal_input
      hidden_states = model_executable(**forward_kwargs)

      # Logits are computed on every worker; only the driver samples.
      logits = self.model.compute_logits(hidden_states, sampling_metadata)
      if not self.is_driver_worker:
          return None
      return self.model.sample(
          logits=logits,
          sampling_metadata=sampling_metadata,
      )
  @torch.inference_mode()
  def profile_run(self) -> None:
      """Execute one dummy prefill batch so peak memory usage can be measured.

      The batch holds ``scheduler_config.max_num_seqs`` prompt-only sequence
      groups whose lengths add up to
      ``scheduler_config.max_num_batched_tokens``. When a vision-language
      config is present, the number of groups is capped at
      ``max_num_batched_tokens // image_feature_size`` so each group has room
      for the image tokens. When LoRA is enabled, ``lora_config.max_loras``
      dummy adapters are registered and assigned to the groups round-robin.
      The model is executed with ``None`` placeholders for every layer's KV
      cache, and the CUDA device is synchronized before returning.
      """
      # Enable top-k sampling to reflect the accurate memory usage.
      sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
      max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
      max_num_seqs = self.scheduler_config.max_num_seqs

      # ``max_loras`` is the maximum number of different requests that can
      # have unique LoRAs, and therefore the max amount of LoRA memory
      # consumption: register that many dummy adapters for the warmup.
      dummy_lora_requests = []
      dummy_lora_requests_per_seq = []
      if self.lora_config:
          assert self.lora_manager is not None
          with self.lora_manager.dummy_lora_cache():
              for idx in range(self.lora_config.max_loras):
                  lora_id = idx + 1
                  dummy_lora_request = LoRARequest(
                      lora_name=f"warmup_{lora_id}",
                      lora_int_id=lora_id,
                      lora_local_path="/not/a/real/path",
                  )
                  self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                   rank=LORA_WARMUP_RANK)
                  dummy_lora_requests.append(dummy_lora_request)
              # Cycle through the dummy adapters so every sequence gets one.
              dummy_lora_requests_per_seq = [
                  dummy_lora_requests[idx % len(dummy_lora_requests)]
                  for idx in range(max_num_seqs)
              ]

      # Additional GPU memory may be needed for vision encoding, which needs
      # to be accounted for when calculating the GPU blocks for the
      # Aphrodite block manager. To exercise the worst scenario for GPU
      # memory consumption, the number of seqs (batch_size) is chosen to
      # maximize the number of images processed. Integer floor division is
      # used (rather than float division + int()) to stay exact.
      if self.vision_language_config:
          max_num_seqs = min(
              max_num_seqs,
              max_num_batched_tokens //
              self.vision_language_config.image_feature_size)

      # Profile memory usage with max_num_seqs sequences and the total
      # number of tokens equal to max_num_batched_tokens.
      seqs: List[SequenceGroupMetadata] = []
      for group_id in range(max_num_seqs):
          # Split the token budget evenly; the first
          # (max_num_batched_tokens % max_num_seqs) groups get one extra.
          seq_len = (max_num_batched_tokens // max_num_seqs +
                     (group_id < max_num_batched_tokens % max_num_seqs))
          seq_data, fake_multi_modal_input = _prepare_fake_inputs(
              seq_len, self.vision_language_config)
          seq = SequenceGroupMetadata(
              request_id=str(group_id),
              is_prompt=True,
              seq_data={group_id: seq_data},
              sampling_params=sampling_params,
              block_tables=None,
              lora_request=dummy_lora_requests_per_seq[group_id]
              if dummy_lora_requests_per_seq else None,
              multi_modal_data=fake_multi_modal_input,
          )
          seqs.append(seq)

      # Run the model with the dummy inputs, using None placeholders for
      # every layer's KV cache.
      num_layers = self.model_config.get_num_layers(self.parallel_config)
      kv_caches = [None] * num_layers
      self.execute_model(seqs, kv_caches)
      torch.cuda.synchronize()
  def remove_all_loras(self):
      """Drop every adapter held by the LoRA manager.

      Raises:
          RuntimeError: if LoRA support is disabled (no LoRA manager).
      """
      manager = self.lora_manager
      if not manager:
          raise RuntimeError("LoRA is not enabled.")
      manager.remove_all_loras()
  def set_active_loras(self, lora_requests: Set[LoRARequest],
                       lora_mapping: LoRAMapping) -> None:
      """Forward the active LoRA set and mapping to the LoRA manager.

      Args:
          lora_requests: Adapters to make active.
          lora_mapping: Mapping handed unchanged to the manager.

      Raises:
          RuntimeError: if LoRA support is disabled (no LoRA manager).
      """
      manager = self.lora_manager
      if not manager:
          raise RuntimeError("LoRA is not enabled.")
      manager.set_active_loras(lora_requests, lora_mapping)
  def add_lora(self, lora_request: LoRARequest) -> bool:
      """Register ``lora_request`` with the LoRA manager.

      Returns:
          Whatever the manager's ``add_lora`` returns.

      Raises:
          RuntimeError: if LoRA support is disabled (no LoRA manager).
      """
      if self.lora_manager:
          return self.lora_manager.add_lora(lora_request)
      raise RuntimeError("LoRA is not enabled.")
  def remove_lora(self, lora_id: int) -> bool:
      """Ask the LoRA manager to drop the adapter with id ``lora_id``.

      Returns:
          Whatever the manager's ``remove_lora`` returns.

      Raises:
          RuntimeError: if LoRA support is disabled (no LoRA manager).
      """
      if self.lora_manager:
          return self.lora_manager.remove_lora(lora_id)
      raise RuntimeError("LoRA is not enabled.")
  def list_loras(self) -> Set[int]:
      """Return the adapter ids reported by the LoRA manager.

      Raises:
          RuntimeError: if LoRA support is disabled (no LoRA manager).
      """
      if self.lora_manager:
          return self.lora_manager.list_loras()
      raise RuntimeError("LoRA is not enabled.")
  @torch.inference_mode()
  def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
      """Capture the model's decode step as CUDA graphs.

      Note that CUDA graph's performance gain is negligible if number
      of batched tokens are larger than 200. And since CUDA graph
      requires fixed sized tensors, supporting large/variable batch
      size requires high GPU memory overhead. Thus, Aphrodite only captures
      decoding requests. Mixed batch (chunked prefill + decoding) or
      prefill requests are not captured.

      Since it is used for decoding-only, it assumes there's only 1 token
      per sequence in the batch.

      One graph is captured for every size in ``_BATCH_SIZES_TO_CAPTURE``
      that does not exceed the padded ``scheduler_config.max_num_seqs``,
      largest first. Each runner is stored in ``self.graph_runners`` under
      its batch size, and ``self.graph_memory_pool`` is updated after every
      capture so the next capture reuses that pool.

      Args:
          kv_caches: Per-layer KV cache tensors handed to every capture.
      """
      assert not self.model_config.enforce_eager
      logger.info("Capturing the model for CUDA graphs. This may lead to "
                  "unexpected consequences if the model is not static. To "
                  "run the model in eager mode, set 'enforce_eager=True' or "
                  "use '--enforce-eager' in the CLI.")
      logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                  "If you are running out of memory, consider decreasing "
                  "`gpu_memory_utilization` or enforcing eager mode. "
                  "You can also reduce the `max_num_seqs` as needed "
                  "to decrease memory usage.")
      start_time = time.perf_counter()

      # Prepare dummy inputs. These will be reused for all batch sizes.
      # Allocate them directly on the GPU instead of building each tensor on
      # the host and copying it over with ``.cuda()``.
      max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
      input_tokens = torch.zeros(max_batch_size, dtype=torch.long,
                                 device="cuda")
      input_positions = torch.zeros(max_batch_size, dtype=torch.long,
                                    device="cuda")
      slot_mapping = torch.full((max_batch_size, ), _PAD_SLOT_ID,
                                dtype=torch.long, device="cuda")
      seq_lens = torch.ones(max_batch_size, dtype=torch.int32, device="cuda")
      block_tables = torch.from_numpy(self.graph_block_tables).cuda()

      graph_batch_size = _get_graph_batch_size(
          self.scheduler_config.max_num_seqs)
      batch_size_capture_list = [
          bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
      ]

      with graph_capture() as graph_capture_context:
          # NOTE: Capturing the largest batch size first may help reduce the
          # memory usage of CUDA graph.
          for batch_size in reversed(batch_size_capture_list):
              # Create dummy attn_metadata for a decode-only batch (no
              # prefills) over the first ``batch_size`` dummy slots.
              attn_metadata = self.attn_backend.make_metadata(
                  num_prefills=0,
                  num_prefill_tokens=0,
                  num_decode_tokens=batch_size,
                  slot_mapping=slot_mapping[:batch_size],
                  seq_lens=None,
                  seq_lens_tensor=seq_lens[:batch_size],
                  max_query_len=None,
                  max_prefill_seq_len=0,
                  max_decode_seq_len=self.max_seq_len_to_capture,
                  query_start_loc=None,
                  seq_start_loc=None,
                  context_lens_tensor=None,
                  block_tables=block_tables[:batch_size],
                  use_cuda_graph=True,
              )

              if self.lora_config:
                  # Capture with an empty active LoRA set and an all-zero
                  # mapping for this batch size.
                  lora_mapping = LoRAMapping(
                      [0] * batch_size,
                      [0] * batch_size,
                  )
                  self.set_active_loras(set(), lora_mapping)

              graph_runner = CUDAGraphRunner(self.model)
              graph_runner.capture(
                  input_tokens[:batch_size],
                  input_positions[:batch_size],
                  kv_caches,
                  attn_metadata,
                  memory_pool=self.graph_memory_pool,
                  stream=graph_capture_context.stream,
              )
              # The next (smaller) capture reuses this graph's memory pool.
              self.graph_memory_pool = graph_runner.graph.pool()
              self.graph_runners[batch_size] = graph_runner

      end_time = time.perf_counter()
      elapsed_time = end_time - start_time
      # This usually takes < 10 seconds.
      logger.info(f"Graph capturing finished in {elapsed_time} secs.")
  @property
  def vocab_size(self) -> int:
      """Vocabulary size as reported by ``model_config.get_vocab_size()``."""
      config = self.model_config
      return config.get_vocab_size()
  class CUDAGraphRunner:
      """Records one model forward pass as a CUDA graph and replays it.

      ``capture`` keeps the tensors it was given as the static input
      buffers and the produced hidden states as the static output buffer.
      ``forward`` copies new inputs into those buffers, replays the graph
      and returns the output buffer.
      """

      def __init__(self, model: nn.Module):
          self.model = model
          self._graph: Optional[torch.cuda.CUDAGraph] = None
          self.input_buffers: Dict[str, torch.Tensor] = {}
          self.output_buffers: Dict[str, torch.Tensor] = {}

      @property
      def graph(self):
          """The captured graph; asserts that ``capture`` has already run."""
          assert self._graph is not None
          return self._graph

      def capture(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
          memory_pool: Optional[Tuple[int, int]],
          stream: torch.cuda.Stream,
          **kwargs,
      ) -> None:
          """Run the model once as a warm-up, then record its forward pass.

          Args:
              input_ids: Token ids; kept as a static input buffer.
              positions: Positions; kept as a static input buffer.
              kv_caches: Per-layer KV cache tensors passed to the model.
              attn_metadata: Attention metadata; its slot mapping and decode
                  seq-lens / block tables are kept as static input buffers.
              memory_pool: Memory pool handle passed to ``torch.cuda.graph``.
              stream: Stream passed to ``torch.cuda.graph`` for the capture.
              **kwargs: Extra keyword arguments forwarded to the model.
          """
          assert self._graph is None
          model_args = (input_ids, positions, kv_caches, attn_metadata)

          # Run the model once without capturing the graph. This is to make
          # sure that the captured graph does not include the kernel
          # launches for initial benchmarking (e.g., Triton autotune).
          self.model(*model_args, **kwargs)
          torch.cuda.synchronize()

          self._graph = torch.cuda.CUDAGraph()
          with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
              hidden_states = self.model(*model_args, **kwargs)
          torch.cuda.synchronize()

          # forward() refills these same tensors before every replay.
          self.input_buffers = {
              "input_ids": input_ids,
              "positions": positions,
              "kv_caches": kv_caches,
              "slot_mapping": attn_metadata.slot_mapping,
              "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
              "block_tables": attn_metadata.decode_metadata.block_tables,
          }
          self.output_buffers = {"hidden_states": hidden_states}

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
          **kwargs,
      ) -> torch.Tensor:
          """Copy new inputs into the static buffers and replay the graph.

          Extra keyword arguments are accepted and ignored. The same output
          tensor object (captured during ``capture``) is returned on every
          call.
          """
          # KV caches are fixed tensors, so we don't need to copy them.
          del kv_caches
          buffers = self.input_buffers
          buffers["input_ids"].copy_(input_ids, non_blocking=True)
          buffers["positions"].copy_(positions, non_blocking=True)
          buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                        non_blocking=True)
          buffers["seq_lens_tensor"].copy_(
              attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
          buffers["block_tables"].copy_(
              attn_metadata.decode_metadata.block_tables, non_blocking=True)

          self.graph.replay()
          return self.output_buffers["hidden_states"]

      def __call__(self, *args, **kwargs):
          return self.forward(*args, **kwargs)
  def _get_graph_batch_size(batch_size: int) -> int:
      """Returns the padded batch size given actual batch size.

      Padded sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
      2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT, ...: values of 2 or
      less are returned unchanged, 3 and 4 map to 4, and anything larger is
      rounded up to the next multiple of ``_BATCH_SIZE_ALIGNMENT``.
      """
      if batch_size > 4:
          num_chunks = ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
                        _BATCH_SIZE_ALIGNMENT)
          return num_chunks * _BATCH_SIZE_ALIGNMENT
      if batch_size > 2:
          return 4
      return batch_size
  def _prepare_fake_inputs(
          seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
      """Prepare fake inputs for profile run.

      Without a vision-language config, the prompt is ``seq_len`` zero
      tokens and no image is returned. With one, the prompt is
      ``image_feature_size`` copies of ``image_token_id`` followed by zero
      tokens up to ``seq_len``. If ``seq_len`` is smaller than
      ``image_feature_size``, no padding is added and the prompt is longer
      than ``seq_len``. The image is an all-zeros float16 tensor of
      ``image_input_shape`` wrapped in ``MultiModalData``.

      Returns:
          A ``(SequenceData, MultiModalData or None)`` pair.
      """
      if not vision_language_config:
          return SequenceData([0] * seq_len), None

      num_image_tokens = vision_language_config.image_feature_size
      image_tokens = [vision_language_config.image_token_id] * num_image_tokens
      padding = [0] * (seq_len - num_image_tokens)
      fake_image_input = MultiModalData(
          type=MultiModalData.Type.IMAGE,
          data=torch.zeros(vision_language_config.image_input_shape,
                           dtype=torch.float16))
      return SequenceData(image_tokens + padding), fake_image_input
  def _is_block_tables_empty(block_tables: Union[None, Dict]):
      """Return True if ``block_tables`` is None or a dict of only None values.

      An empty dict counts as empty. Any value that is neither None nor a
      dict yields False.
      """
      if block_tables is None:
          return True
      if not isinstance(block_tables, dict):
          return False
      return all(table is None for table in block_tables.values())