model_runner.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011
  1. import contextlib
  2. import time
  3. from typing import Dict, List, Optional, Tuple, Set, Union
  4. import numpy as np
  5. import torch
  6. import torch.nn as nn
  7. from loguru import logger
  8. from aphrodite.common.config import (
  9. DeviceConfig,
  10. ModelConfig,
  11. LoRAConfig,
  12. ParallelConfig,
  13. SchedulerConfig,
  14. )
  15. from aphrodite.common.logger import get_loading_progress_bar
  16. from aphrodite.modeling import get_model, InputMetadata, SamplingMetadata
  17. from aphrodite.modeling.megatron import cupy_utils
  18. from aphrodite.modeling.megatron.communication_op import broadcast_tensor_dict
  19. from aphrodite.modeling.megatron.parallel_state import (
  20. get_tensor_model_parallel_world_size,
  21. with_cupy_nccl_for_all_reduce,
  22. )
  23. from aphrodite.modeling.megatron import custom_all_reduce
  24. from aphrodite.common.sampling_params import SamplingParams, SamplingType
  25. from aphrodite.common.sequence import (
  26. SamplerOutput,
  27. SequenceData,
  28. SequenceGroupMetadata,
  29. )
  30. from aphrodite.modeling.sampling_metadata import PersistentMetadata
  31. from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
  32. from aphrodite.lora.layers import LoRAMapping
  33. from aphrodite.lora.request import LoRARequest
  34. from aphrodite.common.utils import in_wsl, measure_cuda_memory
  # One entry of the per-layer kv_caches list; presumably (key, value).
  KVCache = Tuple[torch.Tensor, torch.Tensor]
  # Sentinel slot index used in slot mappings for padded, masked
  # (sliding-window prefix) and dummy (profiling) token positions.
  _PAD_SLOT_ID = -1
  # LoRA rank given to the dummy adapters registered by profile_run.
  LORA_WARMUP_RANK = 8
  # CUDA graphs are captured for batch sizes 1, 2, 4 and then every
  # multiple of 8 from 8 up to and including 256.
  # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
  _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + list(range(8, 257, 8))
  41. class ModelRunner:
  def __init__(
      self,
      model_config: ModelConfig,
      parallel_config: ParallelConfig,
      scheduler_config: SchedulerConfig,
      device_config: DeviceConfig,
      lora_config: Optional[LoRAConfig],
      kv_cache_dtype: Optional[str] = "auto",
      kv_quant_params_path: Optional[str] = None,
      is_driver_worker: bool = False,
  ):
      """Keep references to the engine configs and initialise runtime state.

      The model, block size, CUDA-graph block table and LoRA manager start
      as ``None`` and are filled in later (``load_model`` /
      ``set_block_size``).

      Args:
          model_config: Model configuration; may be None in sampler tests.
          parallel_config: Parallelism configuration (profile_run uses it
              to count layers).
          scheduler_config: Scheduler limits used for LoRA sizing and for
              the profiling batch.
          device_config: Target device; ``DeviceConfig()`` is used when
              None.
          lora_config: LoRA configuration, or None when LoRA is disabled.
          kv_cache_dtype: KV-cache dtype name; ``"int8"`` triggers loading
              per-layer quantization parameters.
          kv_quant_params_path: Directory holding the int8 KV-cache scale
              files; only read when ``kv_cache_dtype == "int8"``.
          is_driver_worker: True if this runner builds and broadcasts the
              batch inputs, False if it receives them.
      """
      self.model_config = model_config
      self.parallel_config = parallel_config
      self.scheduler_config = scheduler_config
      self.lora_config = lora_config
      self.is_driver_worker = is_driver_worker

      # model_config can be None in tests/samplers/test_sampler.py.
      # FIXME: This is a hack to make the tests work. Refactor this.
      has_model_config = model_config is not None
      self.sliding_window = (model_config.get_sliding_window()
                             if has_model_config else None)
      if device_config is None:
          device_config = DeviceConfig()
      self.device_config = device_config
      self.device = device_config.device

      self.model = None  # Set by load_model().
      self.block_size = None  # Set after initial profiling.
      self.lora_manager = None  # Set by load_model() when LoRA is enabled.
      self.graph_runners: Dict[int, CUDAGraphRunner] = {}
      self.graph_memory_pool = None  # Set during graph capture.
      if has_model_config:
          self.max_context_len_to_capture = (
              model_config.max_context_len_to_capture)
      else:
          self.max_context_len_to_capture = 0
      # When using CUDA graph, the input block tables must be padded to
      # max_context_len_to_capture. Building that padded table in Python on
      # every step is expensive, so a numpy table is cached and only the
      # real block numbers are copied in per iteration. Its shape is
      # (max batch size to capture, max context len to capture / block size).
      self.graph_block_tables = None  # Set after initial profiling.

      # Evaluated once; _prepare_sample reads the cached value to decide
      # whether host tensors may be pinned.
      self.in_wsl = in_wsl()

      self.kv_cache_dtype = kv_cache_dtype
      if kv_cache_dtype == "int8":
          self.kv_quant_params = self.load_kv_quant_params(
              model_config, kv_quant_params_path)
      else:
          self.kv_quant_params = None
  def load_kv_quant_params(
      self,
      model_config: "Optional[ModelConfig]",
      kv_quant_params_path: Optional[str],
  ) -> Optional[List[List[float]]]:
      """Load the per-layer scales used by the int8 KV cache.

      Only Llama architectures are accepted. For every hidden layer ``i``
      the file ``layers.{i}.past_kv_scale.0.weight`` inside
      ``kv_quant_params_path`` is read as raw float32 values.

      Args:
          model_config: Model configuration, or None (sampler tests).
          kv_quant_params_path: Directory (str or path-like) that contains
              the per-layer scale files, or None.

      Returns:
          None if ``model_config`` is None; an empty list if
          ``kv_quant_params_path`` is None; otherwise one list of float32
          scalars per hidden layer.

      Raises:
          ValueError: if the model lists an architecture other than
              LlamaForCausalLM / LLaMAForCausalLM.
      """
      import os

      if model_config is None:
          return None
      # Remove it when all models support kv cache int8.
      for arch in model_config.hf_config.architectures:
          if arch not in ("LlamaForCausalLM", "LLaMAForCausalLM"):
              raise ValueError(
                  "KV CACHE INT8 is not supported for model architectures "
                  f"{arch} for now. "
                  "Supported architectures: LlamaForCausalLM and "
                  "LLaMAForCausalLM.")
      num_layers = model_config.hf_config.num_hidden_layers
      # Loop-invariant: without a directory there is nothing to read for any
      # layer, so the result is an empty list.
      if kv_quant_params_path is None:
          return []
      # os.path.join also accepts pathlib.Path, unlike string concatenation.
      return [
          list(
              np.fromfile(
                  os.path.join(kv_quant_params_path,
                               f"layers.{i}.past_kv_scale.0.weight"),
                  dtype=np.float32,
              ))
          for i in range(num_layers)
      ]
  def load_model(self) -> None:
      """Load the model weights and, when LoRA is configured, wrap the model.

      The device memory consumed while loading is stored in
      ``self.model_memory_usage`` and logged, together with the total across
      tensor-parallel ranks. With a LoRA config, the model must expose
      ``supported_lora_modules``, ``embedding_modules`` and
      ``embedding_padding_modules``. It is then replaced by the LoRA-managed
      model returned from ``LRUCacheWorkerLoRAManager.create_lora_manager``.
      """
      with measure_cuda_memory() as mem:
          self.model = get_model(self.model_config, self.device_config,
                                 self.lora_config)
      self.model_memory_usage = mem.consumed_memory

      gib = float(2**30)
      tp_size = get_tensor_model_parallel_world_size()
      logger.info(
          "Model weights loaded. Memory usage: "
          f"{self.model_memory_usage / gib:.2f} GiB x {tp_size} = "
          f"{self.model_memory_usage * tp_size / gib:.2f} GiB")

      vocab_size = self.model.config.vocab_size
      if not self.lora_config:
          return

      assert (hasattr(self.model, "supported_lora_modules")
              and self.model.supported_lora_modules
              ), "Model does not support LoRA"
      assert hasattr(
          self.model,
          "embedding_modules"), "Model does not have embedding_modules"
      assert hasattr(self.model, "embedding_padding_modules"
                     ), "Model does not have embedding_padding_modules"

      sched = self.scheduler_config
      # Token capacity handed to the LoRA manager includes the scheduler's
      # allowed padding on top of the batched-token budget.
      max_lora_tokens = sched.max_num_batched_tokens + sched.max_paddings
      self.lora_manager = LRUCacheWorkerLoRAManager(
          sched.max_num_seqs,
          max_lora_tokens,
          vocab_size,
          self.lora_config,
          self.device,
          self.model.embedding_modules,
          self.model.embedding_padding_modules,
      )
      self.model = self.lora_manager.create_lora_manager(self.model)
  def set_block_size(self, block_size: int) -> None:
      """Record the KV-cache block size and allocate the CUDA-graph block table.

      The cached int32 table has one row per slot of the largest captured
      batch size and enough columns to address
      ``max_context_len_to_capture`` tokens.
      """
      self.block_size = block_size
      # Ceiling division of max_context_len_to_capture by block_size.
      blocks_per_seq = (self.max_context_len_to_capture - 1) // block_size + 1
      num_rows = max(_BATCH_SIZES_TO_CAPTURE)
      self.graph_block_tables = np.zeros((num_rows, blocks_per_seq),
                                         dtype=np.int32)
  def _prepare_prompt(
      self,
      seq_group_metadata_list: List[SequenceGroupMetadata],
  ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
             List[List[int]], List[int], Set[LoRARequest], ]:
      """Build padded model inputs for a batch of prompt (prefill) groups.

      Each group must hold exactly one sequence. When prefix-cached blocks
      are reported through ``computed_block_nums`` and no sliding window is
      configured, only the uncached suffix of the prompt (the "subquery")
      is fed to the model.

      Returns:
          ``(input_tokens, input_positions, input_metadata, prompt_lens,
          subquery_lens, lora_index_mapping, lora_prompt_mapping,
          lora_requests)``. The token, position and slot-mapping tensors are
          padded to the longest subquery. ``lora_index_mapping`` holds one
          padded per-token list per group. ``lora_prompt_mapping`` has one
          entry per row that ``_prepare_sample`` will select.
      """
      assert len(seq_group_metadata_list) > 0
      input_tokens: List[List[int]] = []
      input_positions: List[List[int]] = []
      slot_mapping: List[List[int]] = []
      lora_index_mapping: List[List[int]] = []
      lora_prompt_mapping: List[int] = []
      lora_requests: Set[LoRARequest] = set()
      prompt_lens: List[int] = []
      context_lens: List[int] = []
      subquery_lens: List[int] = []
      prefix_block_tables: List[List[int]] = []
      for seq_group_metadata in seq_group_metadata_list:
          assert seq_group_metadata.is_prompt
          seq_ids = list(seq_group_metadata.seq_data.keys())
          assert len(seq_ids) == 1
          seq_id = seq_ids[0]
          seq_data = seq_group_metadata.seq_data[seq_id]
          prompt_tokens = seq_data.get_token_ids()
          prompt_len = len(prompt_tokens)
          prompt_lens.append(prompt_len)
          computed_len = 0
          # NOTE: This only works for oooooooxxx style attention.
          computed_block_nums = seq_group_metadata.computed_block_nums
          if (computed_block_nums is not None
                  and len(computed_block_nums) > 0
                  and self.sliding_window is None):
              # Prefix is not supported with sliding_window
              computed_len = len(computed_block_nums) * self.block_size
              prompt_tokens = prompt_tokens[computed_len:]
              prefix_block_tables.append(computed_block_nums)
          else:
              prefix_block_tables.append([])
          # Number of prompt tokens actually fed to the model.
          subquery_len = prompt_len - computed_len
          # actual prompt lens
          context_lens.append(computed_len)
          subquery_lens.append(subquery_len)
          input_tokens.append(prompt_tokens)
          # NOTE: Here we assume that the first token in the prompt
          # is always the first token in the sequence.
          input_positions.append(
              list(range(computed_len, computed_len + len(prompt_tokens))))

          lora_id = seq_group_metadata.lora_int_id
          if lora_id > 0:
              lora_requests.add(seq_group_metadata.lora_request)
          lora_index_mapping.append([lora_id] * subquery_len)
          # Must agree with _prepare_sample, which selects every subquery row
          # whenever prompt_logprobs is not None (including 0). A truthiness
          # test here would emit a single entry for prompt_logprobs=0.
          wants_prompt_logprobs = (
              seq_group_metadata.sampling_params.prompt_logprobs is not None)
          lora_prompt_mapping.extend(
              [lora_id] * (subquery_len if wants_prompt_logprobs else 1))

          if seq_group_metadata.block_tables is None:
              # During memory profiling, the block tables are not initialized
              # yet. In this case, we just use a dummy slot mapping with one
              # entry per token fed to the model.
              slot_mapping.append([_PAD_SLOT_ID] * subquery_len)
              continue

          # Compute the slot mapping.
          slot_mapping.append([])
          block_table = seq_group_metadata.block_tables[seq_id]
          # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
          # where start_idx is max(0, prompt_len - sliding_window).
          # For example, if the prompt len is 10, sliding window is 8, and
          # block size is 4, the first two tokens are masked and the slot
          # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
          start_idx = 0
          if self.sliding_window is not None:
              assert computed_len == 0, (
                  "Prefix caching is currently not supported with "
                  "sliding window attention")
              start_idx = max(0, prompt_len - self.sliding_window)
          for i in range(computed_len, prompt_len):
              if i < start_idx:
                  slot_mapping[-1].append(_PAD_SLOT_ID)
                  continue
              block_number = block_table[i // self.block_size]
              block_offset = i % self.block_size
              slot = block_number * self.block_size + block_offset
              slot_mapping[-1].append(slot)

      # Longest subquery in the batch; every per-sequence row is padded to it.
      max_prompt_len = max(subquery_lens)
      input_tokens = _make_tensor_with_pad(
          input_tokens,
          max_prompt_len,
          pad=0,
          dtype=torch.long,
          device=self.device,
      )
      input_positions = _make_tensor_with_pad(
          input_positions,
          max_prompt_len,
          pad=0,
          dtype=torch.long,
          device=self.device,
      )
      slot_mapping = _make_tensor_with_pad(
          slot_mapping,
          max_prompt_len,
          pad=_PAD_SLOT_ID,
          dtype=torch.long,
          device=self.device,
      )
      lora_index_mapping = [
          _pad_to_max(mapping, max_prompt_len, pad=0)
          for mapping in lora_index_mapping
      ]
      context_lens_tensor = torch.tensor(context_lens,
                                         dtype=torch.int,
                                         device=self.device)
      # Prepare prefix block tables
      max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
      block_tables = _make_tensor_with_pad(
          prefix_block_tables,
          max_len=max_prompt_block_table_len,
          pad=0,
          dtype=torch.int,
          device=self.device,
      )
      # Start offset of each sequence in the padded, flattened token layout.
      start_loc_tensor = torch.arange(
          0,
          len(prompt_lens) * max_prompt_len,
          max_prompt_len,
          dtype=torch.long,
          device=self.device,
      )
      prompt_lens_tensor = torch.tensor(prompt_lens,
                                        dtype=torch.long,
                                        device=self.device)
      input_metadata = InputMetadata(
          is_prompt=True,
          slot_mapping=slot_mapping,
          prompt_lens=prompt_lens_tensor,
          max_seq_len=max_prompt_len,
          start_loc=start_loc_tensor,
          max_context_len=None,
          context_lens=context_lens_tensor,
          block_tables=block_tables,
          use_cuda_graph=False,
          kv_cache_dtype=self.kv_cache_dtype,
          kv_quant_params=self.kv_quant_params,
      )
      return (
          input_tokens,
          input_positions,
          input_metadata,
          prompt_lens,
          subquery_lens,
          lora_index_mapping,
          lora_prompt_mapping,
          lora_requests,
      )
  def _prepare_decode(
      self,
      seq_group_metadata_list: List[SequenceGroupMetadata],
  ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[List[int]],
             List[int], Set[LoRARequest], ]:
      """Build model inputs for a batch of decode groups (one token per seq).

      Every sequence contributes its last token at position ``seq_len - 1``.
      CUDA graphs are used when ``enforce_eager`` is off, the batch fits the
      largest captured size, and no context exceeds
      ``max_context_len_to_capture``. In that case the batch is padded to
      ``_get_graph_batch_size(batch_size)``, with dummy rows given
      ``context_len`` 1 and ``_PAD_SLOT_ID`` slots. Block tables are then
      written into the cached ``graph_block_tables`` array.

      Returns:
          ``(input_tokens, input_positions, input_metadata,
          lora_index_mapping, lora_prompt_mapping, lora_requests)``, where
          ``lora_index_mapping`` holds one single-element list per sequence.
      """
      assert len(seq_group_metadata_list) > 0
      input_tokens: List[List[int]] = []
      input_positions: List[List[int]] = []
      slot_mapping: List[List[int]] = []
      context_lens: List[int] = []
      block_tables: List[List[int]] = []
      lora_index_mapping: List[List[int]] = []
      lora_prompt_mapping: List[int] = []
      lora_requests: Set[LoRARequest] = set()
      for seq_group_metadata in seq_group_metadata_list:
          assert not seq_group_metadata.is_prompt
          seq_ids = list(seq_group_metadata.seq_data.keys())
          lora_id = seq_group_metadata.lora_int_id
          if lora_id > 0:
              lora_requests.add(seq_group_metadata.lora_request)
          for seq_id in seq_ids:
              seq_data = seq_group_metadata.seq_data[seq_id]
              generation_token = seq_data.get_last_token_id()
              input_tokens.append([generation_token])
              seq_len = seq_data.get_len()
              position = seq_len - 1
              input_positions.append([position])
              # With a sliding window, attention only spans the last
              # sliding_window tokens.
              context_len = (seq_len if self.sliding_window is None else min(
                  seq_len, self.sliding_window))
              context_lens.append(context_len)
              block_table = seq_group_metadata.block_tables[seq_id]
              block_number = block_table[position // self.block_size]
              block_offset = position % self.block_size
              slot = block_number * self.block_size + block_offset
              slot_mapping.append([slot])
              lora_index_mapping.append([lora_id])
              lora_prompt_mapping.append(lora_id)
              if self.sliding_window is not None:
                  sliding_window_blocks = (self.sliding_window //
                                           self.block_size)
                  block_table = block_table[-sliding_window_blocks:]
              block_tables.append(block_table)

      batch_size = len(input_tokens)
      max_context_len = max(context_lens)
      use_captured_graph = (
          not self.model_config.enforce_eager
          and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
          and max_context_len <= self.max_context_len_to_capture)
      if use_captured_graph:
          # Pad the input tokens, positions, and slot mapping to match the
          # batch size of the captured graph.
          graph_batch_size = _get_graph_batch_size(batch_size)
          assert graph_batch_size >= batch_size
          for _ in range(graph_batch_size - batch_size):
              input_tokens.append([])
              input_positions.append([])
              slot_mapping.append([])
              context_lens.append(1)
              block_tables.append([])
          batch_size = graph_batch_size

      input_tokens = _make_tensor_with_pad(input_tokens,
                                           max_len=1,
                                           pad=0,
                                           dtype=torch.long,
                                           device=self.device)
      input_positions = _make_tensor_with_pad(
          input_positions,
          max_len=1,
          pad=0,
          dtype=torch.long,
          device=self.device,
      )
      slot_mapping = _make_tensor_with_pad(
          slot_mapping,
          max_len=1,
          pad=_PAD_SLOT_ID,
          dtype=torch.long,
          device=self.device,
      )
      context_lens = torch.tensor(context_lens,
                                  dtype=torch.int,
                                  device=self.device)

      if use_captured_graph:
          # The shape of graph_block_tables is
          # [max batch size, max context len // block size].
          # Only the real block numbers are copied into the cached table;
          # padded rows (empty block tables) are left untouched.
          input_block_tables = self.graph_block_tables[:batch_size]
          for i, block_table in enumerate(block_tables):
              if block_table:
                  input_block_tables[i, :len(block_table)] = block_table
          block_tables = torch.tensor(input_block_tables, device=self.device)
      else:
          max_block_table_len = max(
              len(block_table) for block_table in block_tables)
          block_tables = _make_tensor_with_pad(
              block_tables,
              max_len=max_block_table_len,
              pad=0,
              dtype=torch.int,
              device=self.device,
          )

      lora_index_mapping = [
          _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping
      ]

      input_metadata = InputMetadata(
          is_prompt=False,
          slot_mapping=slot_mapping,
          prompt_lens=None,
          max_seq_len=None,
          start_loc=None,
          max_context_len=max_context_len,
          context_lens=context_lens,
          block_tables=block_tables,
          use_cuda_graph=use_captured_graph,
          kv_cache_dtype=self.kv_cache_dtype,
          kv_quant_params=self.kv_quant_params,
      )
      return (
          input_tokens,
          input_positions,
          input_metadata,
          lora_index_mapping,
          lora_prompt_mapping,
          lora_requests,
      )
  def _prepare_sample(
      self,
      seq_group_metadata_list: List[SequenceGroupMetadata],
      prompt_lens: List[int],
      subquery_lens: Optional[List[int]],
  ) -> SamplingMetadata:
      """Work out which model-output rows to sample from, and how.

      In the flattened model output, each prompt group occupies
      ``max(subquery_lens)`` rows (its padded length) and each decode
      sequence occupies one row. ``selected_token_indices`` lists the rows
      to keep: the last prompt token per prompt group, or every subquery
      position when ``prompt_logprobs`` is not None, and one row per decode
      sequence. ``categorized_sample_indices`` maps each ``SamplingType`` to
      positions within those selected rows that actually need sampling.
      Seeded requests get a ``torch.Generator`` created on the first (prompt)
      step and reused afterwards.
      """
      seq_groups: List[Tuple[List[int], SamplingParams]] = []
      selected_token_indices: List[int] = []
      generators: List[torch.Generator] = []
      selected_token_start_idx = 0
      categorized_sample_indices = {t: [] for t in SamplingType}
      categorized_sample_indices_start_idx = 0
      # Rows occupied by each (padded) prompt group in the model output.
      max_subquery_len = max(subquery_lens) if subquery_lens else 1

      for i, seq_group_metadata in enumerate(seq_group_metadata_list):
          seq_ids = list(seq_group_metadata.seq_data.keys())
          sampling_params = seq_group_metadata.sampling_params
          seq_groups.append((seq_ids, sampling_params))

          if seq_group_metadata.is_prompt:
              assert len(seq_ids) == 1
              assert subquery_lens is not None
              subquery_len = subquery_lens[i]
              if sampling_params.prompt_logprobs is not None:
                  # NOTE: prompt token positions do not need sample, skip
                  categorized_sample_indices_start_idx += subquery_len - 1

              categorized_sample_indices[
                  sampling_params.sampling_type].append(
                      categorized_sample_indices_start_idx)
              categorized_sample_indices_start_idx += 1

              if sampling_params.prompt_logprobs is not None:
                  selected_token_indices.extend(
                      range(
                          selected_token_start_idx,
                          selected_token_start_idx + subquery_len - 1,
                      ))
              selected_token_indices.append(selected_token_start_idx +
                                            subquery_len - 1)
              selected_token_start_idx += max_subquery_len

              if sampling_params.seed is not None:
                  # Create the generator on the runner's configured device
                  # rather than a fixed CUDA device, so seeded sampling
                  # follows device_config like every other tensor here.
                  seq_group_metadata.state.generator = torch.Generator(
                      device=self.device).manual_seed(sampling_params.seed)
          else:
              num_seqs = len(seq_ids)
              selected_token_indices.extend(
                  range(
                      selected_token_start_idx,
                      selected_token_start_idx + num_seqs,
                  ))
              selected_token_start_idx += num_seqs

              categorized_sample_indices[
                  sampling_params.sampling_type].extend(
                      range(
                          categorized_sample_indices_start_idx,
                          categorized_sample_indices_start_idx + num_seqs,
                      ))
              categorized_sample_indices_start_idx += num_seqs

          if sampling_params.seed is not None:
              generators.append(seq_group_metadata.state.generator)

      selected_token_indices = _async_h2d(
          selected_token_indices,
          dtype=torch.long,
          target_device=self.device,
          pin_memory=not self.in_wsl,
      )
      categorized_sample_indices = {
          t: _async_h2d(
              indices,
              dtype=torch.int,
              target_device=self.device,
              pin_memory=not self.in_wsl,
          )
          for t, indices in categorized_sample_indices.items()
      }

      seq_data: Dict[int, SequenceData] = {}
      for seq_group_metadata in seq_group_metadata_list:
          seq_data.update(seq_group_metadata.seq_data)

      seq_persistence_data: Dict[int, dict] = {}
      for grp in seq_group_metadata_list:
          seq_persistence_data.update(grp.persistent_data)

      sampling_metadata = SamplingMetadata(
          seq_groups=seq_groups,
          seq_data=seq_data,
          prompt_lens=prompt_lens,
          selected_token_indices=selected_token_indices,
          categorized_sample_indices=categorized_sample_indices,
          generators=generators,
          persistent_metadata=PersistentMetadata(seq_persistence_data),
      )
      return sampling_metadata
  def prepare_input_tensors(
      self,
      seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
  ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata,
             Set[LoRARequest], Optional[LoRAMapping], ]:
      """Produce model inputs on every worker for one engine step.

      On the driver worker, inputs are built from ``seq_group_metadata_list``
      and the tensors and metadata are broadcast with
      ``broadcast_tensor_dict(..., src=0)``. All groups must be prompts, or
      all must be decodes. Every other worker ignores the argument, receives
      the broadcast, and rebuilds ``InputMetadata`` plus a ``SamplingMetadata``
      that carries only ``selected_token_indices`` and has
      ``perform_sampling=False``.

      Returns:
          ``(input_tokens, input_positions, input_metadata,
          sampling_metadata, lora_requests, lora_mapping)``.
          ``lora_mapping`` is None when LoRA is disabled.
      """
      if self.is_driver_worker:
          # NOTE: We assume that all sequences in the group are all prompts or
          # all decodes.
          is_prompt = seq_group_metadata_list[0].is_prompt
          # Prepare input tensors.
          if is_prompt:
              (
                  input_tokens,
                  input_positions,
                  input_metadata,
                  prompt_lens,
                  subquery_lens,
                  lora_index_mapping,
                  lora_prompt_mapping,
                  lora_requests,
              ) = self._prepare_prompt(seq_group_metadata_list)
          else:
              (
                  input_tokens,
                  input_positions,
                  input_metadata,
                  lora_index_mapping,
                  lora_prompt_mapping,
                  lora_requests,
              ) = self._prepare_decode(seq_group_metadata_list)
              prompt_lens = []
              subquery_lens = None
          sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                   prompt_lens,
                                                   subquery_lens)

          if self.lora_config:
              # LoRAMapping expects one flat per-token index list.
              flat_lora_index_mapping = [
                  item for sublist in lora_index_mapping for item in sublist
              ]
              lora_mapping = LoRAMapping(
                  flat_lora_index_mapping,
                  lora_prompt_mapping,
              )
          else:
              lora_mapping = None

          # Broadcast the metadata.
          metadata_dict = {
              "input_tokens": input_tokens,
              "input_positions": input_positions,
              "is_prompt": input_metadata.is_prompt,
              "slot_mapping": input_metadata.slot_mapping,
              "prompt_lens": input_metadata.prompt_lens,
              "max_seq_len": input_metadata.max_seq_len,
              "start_loc": input_metadata.start_loc,
              "max_context_len": input_metadata.max_context_len,
              "context_lens": input_metadata.context_lens,
              "block_tables": input_metadata.block_tables,
              "use_cuda_graph": input_metadata.use_cuda_graph,
              "kv_cache_dtype": input_metadata.kv_cache_dtype,
              "kv_quant_params": input_metadata.kv_quant_params,
              "selected_token_indices":
              sampling_metadata.selected_token_indices,  # noqa
              "lora_requests": lora_requests,
              "lora_mapping": lora_mapping,
          }
          broadcast_tensor_dict(metadata_dict, src=0)
      else:
          metadata_dict = broadcast_tensor_dict(src=0)
          input_tokens = metadata_dict["input_tokens"]
          input_positions = metadata_dict["input_positions"]
          lora_mapping = metadata_dict["lora_mapping"]
          lora_requests = metadata_dict["lora_requests"]
          input_metadata = InputMetadata(
              is_prompt=metadata_dict["is_prompt"],
              slot_mapping=metadata_dict["slot_mapping"],
              prompt_lens=metadata_dict["prompt_lens"],
              max_seq_len=metadata_dict["max_seq_len"],
              start_loc=metadata_dict["start_loc"],
              max_context_len=metadata_dict["max_context_len"],
              context_lens=metadata_dict["context_lens"],
              block_tables=metadata_dict["block_tables"],
              use_cuda_graph=metadata_dict["use_cuda_graph"],
              kv_cache_dtype=metadata_dict["kv_cache_dtype"],
              kv_quant_params=metadata_dict["kv_quant_params"],
          )
          # Non-driver workers only need the selected rows; the driver
          # performs the actual sampling.
          sampling_metadata = SamplingMetadata(
              seq_groups=None,
              seq_data=None,
              prompt_lens=None,
              selected_token_indices=metadata_dict["selected_token_indices"],
              categorized_sample_indices=None,
              generators=None,
              perform_sampling=False,
          )

      return (
          input_tokens,
          input_positions,
          input_metadata,
          sampling_metadata,
          lora_requests,
          lora_mapping,
      )
  @torch.inference_mode()
  def execute_model(
      self,
      seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
      kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
  ) -> Optional[SamplerOutput]:
      """Run one forward pass followed by sampling for the given batch.

      Inputs come from ``prepare_input_tensors`` (built on the driver,
      received elsewhere). Active LoRAs are set first when LoRA is enabled.
      The forward pass replays the CUDA graph keyed by the padded batch size
      when the inputs were prepared for graphs; otherwise the model runs
      eagerly. Whatever ``self.model.sample`` returns is passed through.
      """
      prepared = self.prepare_input_tensors(seq_group_metadata_list)
      (input_tokens, input_positions, input_metadata, sampling_metadata,
       lora_requests, lora_mapping) = prepared

      if self.lora_config:
          self.set_active_loras(lora_requests, lora_mapping)

      # Execute the model.
      if not input_metadata.use_cuda_graph:
          runner = self.model
      else:
          runner = self.graph_runners[input_tokens.shape[0]]
      hidden_states = runner(
          input_ids=input_tokens,
          positions=input_positions,
          kv_caches=kv_caches,
          input_metadata=input_metadata,
      )

      # Sample the next token.
      return self.model.sample(
          hidden_states=hidden_states,
          sampling_metadata=sampling_metadata,
      )
  @torch.inference_mode()
  def profile_run(self) -> None:
      """Run a worst-case prefill batch with no KV cache to gauge memory use.

      Builds ``max_num_seqs`` dummy prompts of token id 0, whose lengths sum
      to ``max_num_batched_tokens``. They use top-k sampling over almost the
      whole vocabulary and are executed with ``(None, None)`` KV-cache
      entries for every layer. With LoRA enabled, ``max_loras`` dummy
      adapters of rank ``LORA_WARMUP_RANK`` are registered first and handed
      out to the sequences round-robin.
      """
      # Enable top-k sampling to reflect the accurate memory usage.
      vocab_size = self.model_config.get_vocab_size()
      sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
      max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
      max_num_seqs = self.scheduler_config.max_num_seqs

      # Give up to max_loras requests their own LoRA so the profile covers
      # the largest number of distinct adapters in one batch. The dummy
      # requests point at a fake path; the manager registers them via
      # add_dummy_lora.
      dummy_lora_requests_per_seq = []
      if self.lora_config:
          dummy_lora_requests = []
          for lora_id in range(1, self.lora_config.max_loras + 1):
              warmup_request = LoRARequest(
                  lora_name=f"warmup_{lora_id}",
                  lora_int_id=lora_id,
                  lora_local_path="/not/a/real/path",
              )
              self.lora_manager.add_dummy_lora(warmup_request,
                                               rank=LORA_WARMUP_RANK)
              dummy_lora_requests.append(warmup_request)
          dummy_lora_requests_per_seq = [
              dummy_lora_requests[idx % len(dummy_lora_requests)]
              for idx in range(max_num_seqs)
          ]

      # Profile memory usage with max_num_sequences sequences and the total
      # number of tokens equal to max_num_batched_tokens: every sequence gets
      # the floor share and the first (tokens % seqs) get one extra token.
      seqs: List[SequenceGroupMetadata] = []
      for group_id in range(max_num_seqs):
          extra_token = group_id < max_num_batched_tokens % max_num_seqs
          seq_len = max_num_batched_tokens // max_num_seqs + extra_token
          if dummy_lora_requests_per_seq:
              lora_request = dummy_lora_requests_per_seq[group_id]
          else:
              lora_request = None
          seqs.append(
              SequenceGroupMetadata(
                  request_id=str(group_id),
                  is_prompt=True,
                  seq_data={group_id: SequenceData([0] * seq_len)},
                  sampling_params=sampling_params,
                  block_tables=None,
                  persistent_data={},
                  lora_request=lora_request,
              ))

      # Run the model with the dummy inputs.
      num_layers = self.model_config.get_num_layers(self.parallel_config)
      kv_caches = [(None, None)] * num_layers
      self.execute_model(seqs, kv_caches)
      torch.cuda.synchronize()
  def remove_all_loras(self) -> bool:
      """Delegate to the LoRA manager's ``remove_all_loras``.

      Raises:
          RuntimeError: if LoRA is not enabled (no LoRA manager).
      """
      manager = self.lora_manager
      if manager:
          return manager.remove_all_loras()
      raise RuntimeError("LoRA is not enabled.")
  def set_active_loras(self, lora_requests: Set[LoRARequest],
                       lora_mapping: LoRAMapping) -> None:
      """Forward the active LoRA requests and token mapping to the manager.

      ``lora_requests`` is annotated as a Set because that is what
      ``_prepare_prompt`` / ``_prepare_decode`` collect and what
      ``execute_model`` passes in. The value is forwarded unchanged.

      Raises:
          RuntimeError: if LoRA is not enabled (no LoRA manager).
      """
      if not self.lora_manager:
          raise RuntimeError("LoRA is not enabled.")
      self.lora_manager.set_active_loras(lora_requests, lora_mapping)
  def add_lora(self, lora_request: LoRARequest) -> bool:
      """Delegate adding ``lora_request`` to the LoRA manager.

      Returns whatever the manager's ``add_lora`` returns.

      Raises:
          RuntimeError: if LoRA is not enabled (no LoRA manager).
      """
      if self.lora_manager:
          return self.lora_manager.add_lora(lora_request)
      raise RuntimeError("LoRA is not enabled.")
  def remove_lora(self, lora_id: int) -> bool:
      """Delegate removal of the adapter with integer id ``lora_id``.

      Returns whatever the manager's ``remove_lora`` returns.

      Raises:
          RuntimeError: if LoRA is not enabled (no LoRA manager).
      """
      manager = self.lora_manager
      if not manager:
          raise RuntimeError("LoRA is not enabled.")
      return manager.remove_lora(lora_id)
  def list_loras(self) -> Set[int]:
      """Return the manager's ``list_loras()`` result.

      Raises:
          RuntimeError: if LoRA is not enabled (no LoRA manager).
      """
      if self.lora_manager:
          return self.lora_manager.list_loras()
      raise RuntimeError("LoRA is not enabled.")
  @torch.inference_mode()
  def capture_model(self, kv_caches: List[KVCache]) -> None:
      """Capture one decode-step CUDA graph per supported batch size.

      For every size in ``_BATCH_SIZES_TO_CAPTURE`` that does not exceed
      ``_get_graph_batch_size(max_num_seqs)``, a ``CUDAGraphRunner`` is
      captured over slices of shared dummy input tensors. The sizes are
      processed largest first, and each runner is stored in
      ``self.graph_runners[batch_size]``. All captures share one graph memory
      pool (``self.graph_memory_pool``), which is updated after each capture.

      Args:
          kv_caches: KV caches passed to every captured forward pass.
              NOTE(review): the captured graphs presumably keep using these
              exact tensors at replay time. Confirm that callers pass the
              live caches.
      """
      # NOTE: This is a hack to ensure that the NCCL backend is never
      # deleted before the CUDA graph
      self.cupy_nccl_backend = cupy_utils.get_nccl_backend()

      assert not self.model_config.enforce_eager
      logger.info("Capturing the model for CUDA graphs. This may lead to "
                  "unexpected consequences if the model is not static. To "
                  "run the model in eager mode, set 'enforce_eager=True' or "
                  "use '--enforce-eager' in the CLI.")
      logger.warning("CUDA graphs can take additional 1~3 GiB of memory "
                     "per GPU. If you are running out of memory, consider "
                     "decreasing `gpu_memory_utilization` or enforcing "
                     "eager mode.")
      start_time = time.perf_counter()

      # Prepare dummy inputs. These will be reused for all batch sizes.
      max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
      input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
      input_positions = torch.zeros(max_batch_size, 1,
                                    dtype=torch.long).cuda()
      slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda()
      slot_mapping.fill_(_PAD_SLOT_ID)
      context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
      block_tables = torch.from_numpy(self.graph_block_tables).cuda()

      # _get_graph_batch_size rounds up (e.g. 5 -> 8), so graph_batch_size
      # can exceed max_num_seqs. Those padded sizes are captured on purpose.
      # A full batch is presumably padded to graph_batch_size before replay
      # (NOTE(review): confirm in the decode input preparation).
      # Skipping sizes above max_num_seqs inside the loop would have two
      # effects: that padded graph would be missing, and the progress bar
      # would never reach its total.
      graph_batch_size = _get_graph_batch_size(
          self.scheduler_config.max_num_seqs)
      batch_size_capture_list = [
          bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
      ]

      # NOTE: Capturing the largest batch size first may help reduce the
      # memory usage of CUDA graph.
      # NOTE: There are 3 backends for all-reduce: custom all-reduce
      # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
      # either custom all-reduce kernel or CuPy NCCL. When not using CUDA
      # graph, we use either custom all-reduce kernel or PyTorch NCCL.
      # We always prioritize using custom all-reduce kernel but fall back
      # to PyTorch or CuPy NCCL if it is disabled or not supported.
      # Initialize a new progress bar
      progress = get_loading_progress_bar()
      task = progress.add_task("[cyan]Capturing graph...",
                               total=len(batch_size_capture_list))
      with progress, custom_all_reduce.capture():
          for batch_size in reversed(batch_size_capture_list):
              # Dummy decode-phase metadata over the first batch_size rows
              # of the shared dummy tensors.
              input_metadata = InputMetadata(
                  is_prompt=False,
                  slot_mapping=slot_mapping[:batch_size],
                  prompt_lens=None,
                  max_seq_len=None,
                  start_loc=None,
                  max_context_len=self.max_context_len_to_capture,
                  context_lens=context_lens[:batch_size],
                  block_tables=block_tables[:batch_size],
                  use_cuda_graph=True,
                  kv_cache_dtype=self.kv_cache_dtype,
                  kv_quant_params=self.kv_quant_params,
              )

              if self.lora_config:
                  # All-zero mapping with an empty request set, so the capture
                  # runs without any adapter requested (presumably index 0
                  # means "no LoRA"; confirm against LoRAMapping).
                  lora_mapping = LoRAMapping(
                      [0] * batch_size,
                      [0] * batch_size,
                  )
                  self.set_active_loras(set(), lora_mapping)

              graph_runner = CUDAGraphRunner(self.model)
              graph_runner.capture(
                  input_tokens[:batch_size],
                  input_positions[:batch_size],
                  kv_caches,
                  input_metadata,
                  memory_pool=self.graph_memory_pool,
              )
              # Reuse this graph's pool for the next (smaller) capture.
              self.graph_memory_pool = graph_runner.graph.pool()
              self.graph_runners[batch_size] = graph_runner
              # Update the progress bar
              progress.update(task, advance=1)

      end_time = time.perf_counter()
      elapsed_time = end_time - start_time
      # This usually takes < 10 seconds.
      logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
  798. def __del__(self) -> None:
  799. # Delete the CUDA graphs before deleting the CuPy NCCL communicator.
  800. # NOTE: This is necessary because otherwise deadlocks can
  801. # happen.
  802. # FIXME: This is a bit hacky. Find a more robust solution.
  803. self.graph_runners.clear()
  804. self.cupy_nccl_backend = None
  class CUDAGraphRunner:
      """Record one model forward pass as a CUDA graph and replay it later.

      ``capture`` runs the model once eagerly, then records a second run into
      a ``torch.cuda.CUDAGraph`` and keeps the tensors it was recorded with.
      ``forward`` copies new inputs into those kept tensors, replays the graph
      and returns the output tensor recorded during capture. The same tensor
      object is returned on every call.
      """

      def __init__(self, model: nn.Module):
          self.model = model
          # Set by capture(); None means "not captured yet".
          self.graph = None
          # Tensors the graph was recorded against, keyed by argument name.
          self.input_buffers: Dict[str, torch.Tensor] = {}
          self.output_buffers: Dict[str, torch.Tensor] = {}

      def capture(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[KVCache],
          input_metadata: InputMetadata,
          memory_pool,
      ) -> None:
          """Warm the model up, then capture its forward pass into a graph.

          May only be called once per runner (asserts ``self.graph is None``).

          Args:
              input_ids: Token-id tensor recorded as a graph input.
              positions: Position tensor recorded as a graph input.
              kv_caches: KV caches used by the recorded forward pass.
              input_metadata: Metadata whose ``slot_mapping``,
                  ``context_lens`` and ``block_tables`` tensors are kept as
                  graph inputs.
              memory_pool: Passed as ``pool=`` to ``torch.cuda.graph``.
          """
          assert self.graph is None
          model_args = (input_ids, positions, kv_caches, input_metadata)

          # Run the model once without capturing the graph.
          # This is to make sure that the captured graph does not include the
          # kernel launches for initial benchmarking (e.g., Triton autotune).
          with _maybe_cupy_nccl():
              self.model(*model_args)
          torch.cuda.synchronize()

          # Capture the graph.
          # NOTE: Python 3.8 does not support multi-line with statements.
          # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
          self.graph = torch.cuda.CUDAGraph()
          with torch.cuda.graph(self.graph,
                                pool=memory_pool), _maybe_cupy_nccl():
              recorded_output = self.model(*model_args)
          torch.cuda.synchronize()

          # Keep references to the exact tensors the graph reads and writes.
          # forward() refills the inputs in place before each replay.
          self.input_buffers = {
              "input_ids": input_ids,
              "positions": positions,
              "kv_caches": kv_caches,
              "slot_mapping": input_metadata.slot_mapping,
              "context_lens": input_metadata.context_lens,
              "block_tables": input_metadata.block_tables,
          }
          self.output_buffers = {"hidden_states": recorded_output}

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          """Copy fresh inputs into the recorded buffers and replay the graph.

          ``kv_caches`` is ignored; the caches recorded during capture are
          reused. All copies are issued with ``non_blocking=True``.

          Returns:
              The output tensor recorded during capture (the same object on
              every call).
          """
          # KV caches are fixed tensors, so we don't need to copy them.
          del kv_caches
          buffers = self.input_buffers
          buffers["input_ids"].copy_(input_ids, non_blocking=True)
          buffers["positions"].copy_(positions, non_blocking=True)
          # The metadata attribute names double as the buffer keys.
          for field in ("slot_mapping", "context_lens", "block_tables"):
              buffers[field].copy_(getattr(input_metadata, field),
                                   non_blocking=True)
          self.graph.replay()
          return self.output_buffers["hidden_states"]

      def __call__(self, *args, **kwargs):
          # Make the runner callable like the nn.Module it wraps.
          return self.forward(*args, **kwargs)
  @contextlib.contextmanager
  def _maybe_cupy_nccl():
      """Enter ``with_cupy_nccl_for_all_reduce()`` only when it is needed.

      That scope is used when ``cupy_utils.is_initialized()`` is true and
      ``custom_all_reduce.is_initialized()`` is false. Otherwise the body runs
      inside a no-op context.
      """
      use_cupy_nccl = (cupy_utils.is_initialized()
                       and not custom_all_reduce.is_initialized())
      scope = (with_cupy_nccl_for_all_reduce()
               if use_cupy_nccl else contextlib.nullcontext())
      with scope:
          yield
  886. def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
  887. assert len(x) <= max_len
  888. return x + [pad] * (max_len - len(x))
  def _make_tensor_with_pad(
      x: List[List[int]],
      max_len: int,
      pad: int,
      dtype: torch.dtype,
      device: Optional[Union[str, torch.device]],
  ) -> torch.Tensor:
      """Build a tensor from ragged rows, right-padding each row to ``max_len``.

      Each row of ``x`` is padded with ``pad`` via ``_pad_to_max``, which
      asserts that no row is longer than ``max_len``. The resulting
      rectangular list is converted with
      ``torch.tensor(..., dtype=dtype, device=device)``. For non-empty ``x``
      the result has shape ``(len(x), max_len)``.
      """
      return torch.tensor(
          [_pad_to_max(row, max_len, pad) for row in x],
          dtype=dtype,
          device=device,
      )
  898. def _get_graph_batch_size(batch_size: int) -> int:
  899. if batch_size <= 2:
  900. return batch_size
  901. elif batch_size <= 4:
  902. return 4
  903. else:
  904. return (batch_size + 7) // 8 * 8
  905. def _async_h2d(
  906. data: list,
  907. dtype: torch.dtype,
  908. target_device: Union[str, torch.device],
  909. pin_memory: bool,
  910. ) -> torch.Tensor:
  911. t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
  912. return t.to(device=target_device, non_blocking=True)