model_runner.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935
  1. import contextlib
  2. import time
  3. from typing import Dict, List, Optional, Tuple, Set, Union
  4. import numpy as np
  5. import torch
  6. import torch.nn as nn
  7. from loguru import logger
  8. from aphrodite.common.config import (DeviceConfig, ModelConfig, LoRAConfig,
  9. ParallelConfig, SchedulerConfig)
  10. from aphrodite.common.logger import get_loading_progress_bar
  11. from aphrodite.modeling import get_model, InputMetadata, SamplingMetadata
  12. from aphrodite.modeling.megatron import cupy_utils
  13. from aphrodite.modeling.megatron.communication_op import (broadcast_tensor_dict
  14. )
  15. from aphrodite.modeling.megatron.parallel_state import (
  16. with_cupy_nccl_for_all_reduce)
  17. from aphrodite.modeling.megatron import custom_all_reduce
  18. from aphrodite.common.sampling_params import SamplingParams, SamplingType
  19. from aphrodite.common.sequence import (SamplerOutput, SequenceData,
  20. SequenceGroupMetadata)
  21. from aphrodite.modeling.sampling_metadata import PersistentMetadata
  22. from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
  23. from aphrodite.lora.layers import LoRAMapping
  24. from aphrodite.lora.request import LoRARequest
  25. from aphrodite.common.utils import in_wsl, measure_cuda_memory
  26. KVCache = Tuple[torch.Tensor, torch.Tensor]
  27. _PAD_SLOT_ID = -1
  28. LORA_WARMUP_RANK = 8
  29. # Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
  30. # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
  31. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
  32. class ModelRunner:
  33. def __init__(
  34. self,
  35. model_config: ModelConfig,
  36. parallel_config: ParallelConfig,
  37. scheduler_config: SchedulerConfig,
  38. device_config: DeviceConfig,
  39. lora_config: Optional[LoRAConfig],
  40. kv_cache_dtype: Optional[str] = "auto",
  41. kv_quant_params_path: Optional[str] = None,
  42. is_driver_worker: bool = False,
  43. ):
  44. self.model_config = model_config
  45. self.parallel_config = parallel_config
  46. self.scheduler_config = scheduler_config
  47. self.lora_config = lora_config
  48. self.is_driver_worker = is_driver_worker
  49. # model_config can be None in tests/samplers/test_sampler.py.
  50. # FIXME: This is a hack to make the tests work. Refactor this.
  51. self.sliding_window = (model_config.get_sliding_window()
  52. if model_config is not None else None)
  53. self.device_config = (device_config
  54. if device_config is not None else DeviceConfig())
  55. self.device = self.device_config.device
  56. self.model = None
  57. self.block_size = None # Set after initial profiling.
  58. self.lora_manager = None
  59. self.graph_runners: Dict[int, CUDAGraphRunner] = {}
  60. self.graph_memory_pool = None # Set during graph capture.
  61. self.max_context_len_to_capture = (
  62. self.model_config.max_context_len_to_capture
  63. if self.model_config is not None else 0)
  64. # When using CUDA graph, the input block tables must be padded to
  65. # max_context_len_to_capture. However, creating the block table in
  66. # Python can be expensive. To optimize this, we cache the block table
  67. # in numpy and only copy the actual input content at every iteration.
  68. # The shape of the cached block table will be
  69. # (max batch size to capture, max context len to capture / block size).
  70. self.graph_block_tables = None # Set after initial profiling.
  71. # cache in_wsl result
  72. self.in_wsl = in_wsl()
  73. self.kv_cache_dtype = kv_cache_dtype
  74. self.kv_quant_params = self.load_kv_quant_params(
  75. model_config,
  76. kv_quant_params_path) if self.kv_cache_dtype == "int8" else None
  77. def load_kv_quant_params(self, model_config: ModelConfig,
  78. kv_quant_params_path: str) -> List[List[float]]:
  79. if model_config is None:
  80. return None
  81. # Remove it when all models support kv cache int8.
  82. architectures = model_config.hf_config.architectures
  83. for arch in architectures:
  84. if arch not in ["LlamaForCausalLM", "LLaMAForCausalLM"]:
  85. raise ValueError(
  86. f"KV CACHE INT8 is not supported for model architectures {arch} for now. "
  87. f"Supported architectures: LlamaForCausalLM and LLaMAForCausalLM."
  88. )
  89. num_layers = model_config.hf_config.num_hidden_layers
  90. kv_quant_params = []
  91. for i in range(num_layers):
  92. if kv_quant_params_path is not None:
  93. path = kv_quant_params_path + f"/layers.{i}.past_kv_scale.0.weight"
  94. kv_quant_param = list(np.fromfile(path, dtype=np.float32))
  95. kv_quant_params.append(kv_quant_param)
  96. return kv_quant_params
  97. def load_model(self) -> None:
  98. with measure_cuda_memory() as m:
  99. self.model = get_model(self.model_config, self.device_config,
  100. self.lora_config)
  101. self.model_memory_usage = m.consumed_memory
  102. logger.info("Model loaded. Memory usage: "
  103. f"{self.model_memory_usage / float(2**30):.2f} GiB")
  104. vocab_size = self.model.config.vocab_size
  105. if self.lora_config:
  106. assert hasattr(
  107. self.model, "supported_lora_modules"
  108. ) and self.model.supported_lora_modules, "Model does not support LoRA"
  109. assert hasattr(
  110. self.model,
  111. "embedding_modules"), "Model does not have embedding_modules"
  112. assert hasattr(self.model, "embedding_padding_modules"
  113. ), "Model does not have embedding_padding_modules"
  114. self.lora_manager = LRUCacheWorkerLoRAManager(
  115. self.scheduler_config.max_num_seqs,
  116. self.scheduler_config.max_num_batched_tokens +
  117. self.scheduler_config.max_paddings, vocab_size,
  118. self.lora_config, self.device, self.model.embedding_modules,
  119. self.model.embedding_padding_modules)
  120. self.model = self.lora_manager.create_lora_manager(self.model)
  121. def set_block_size(self, block_size: int) -> None:
  122. self.block_size = block_size
  123. max_num_blocks = (self.max_context_len_to_capture + block_size -
  124. 1) // block_size
  125. self.graph_block_tables = np.zeros(
  126. (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
  127. def _prepare_prompt(
  128. self,
  129. seq_group_metadata_list: List[SequenceGroupMetadata],
  130. ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
  131. List[int], List[int], Set[LoRARequest]]:
  132. assert len(seq_group_metadata_list) > 0
  133. input_tokens: List[List[int]] = []
  134. input_positions: List[List[int]] = []
  135. slot_mapping: List[List[int]] = []
  136. lora_index_mapping: List[int] = []
  137. lora_prompt_mapping: List[int] = []
  138. lora_requests: Set[LoRARequest] = set()
  139. prompt_lens: List[int] = []
  140. context_lens: List[int] = []
  141. subquery_lens: List[int] = []
  142. prefix_block_tables: List[List[int]] = []
  143. for seq_group_metadata in seq_group_metadata_list:
  144. assert seq_group_metadata.is_prompt
  145. seq_ids = list(seq_group_metadata.seq_data.keys())
  146. assert len(seq_ids) == 1
  147. seq_id = seq_ids[0]
  148. seq_data = seq_group_metadata.seq_data[seq_id]
  149. prompt_tokens = seq_data.get_token_ids()
  150. prompt_len = len(prompt_tokens)
  151. prompt_lens.append(prompt_len)
  152. computed_len = 0
  153. # NOTE: This only works for oooooooxxx style attention.
  154. computed_block_nums = seq_group_metadata.computed_block_nums
  155. if computed_block_nums is not None and len(
  156. computed_block_nums) > 0 and self.sliding_window is None:
  157. # Prefix is not supported with sliding_window
  158. computed_len = len(computed_block_nums) * self.block_size
  159. prompt_tokens = prompt_tokens[computed_len:]
  160. prefix_block_tables.append(computed_block_nums)
  161. else:
  162. prefix_block_tables.append([])
  163. # actual prompt lens
  164. context_lens.append(computed_len)
  165. subquery_lens.append(prompt_len - computed_len)
  166. input_tokens.append(prompt_tokens)
  167. # NOTE: Here we assume that the first token in the prompt
  168. # is always the first token in the sequence.
  169. input_positions.append(
  170. list(range(computed_len, computed_len + len(prompt_tokens))))
  171. lora_id = seq_group_metadata.lora_int_id
  172. if lora_id > 0:
  173. lora_requests.add(seq_group_metadata.lora_request)
  174. lora_index_mapping.append([lora_id] * (prompt_len - computed_len))
  175. lora_prompt_mapping.extend(
  176. [lora_id] *
  177. (prompt_len - computed_len
  178. if seq_group_metadata.sampling_params.prompt_logprobs else 1))
  179. if seq_group_metadata.block_tables is None:
  180. # During memory profiling, the block tables are not initialized
  181. # yet. In this case, we just use a dummy slot mapping.
  182. slot_mapping.append([_PAD_SLOT_ID] * prompt_len)
  183. continue
  184. # Compute the slot mapping.
  185. slot_mapping.append([])
  186. block_table = seq_group_metadata.block_tables[seq_id]
  187. # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
  188. # where start_idx is max(0, prompt_len - sliding_window).
  189. # For example, if the prompt len is 10, sliding window is 8, and
  190. # block size is 4, the first two tokens are masked and the slot
  191. # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
  192. start_idx = 0
  193. if self.sliding_window is not None:
  194. assert computed_len == 0, (
  195. "Prefix caching is currently not supported with "
  196. "sliding window attention")
  197. start_idx = max(0, prompt_len - self.sliding_window)
  198. for i in range(computed_len, prompt_len):
  199. if i < start_idx:
  200. slot_mapping[-1].append(_PAD_SLOT_ID)
  201. continue
  202. block_number = block_table[i // self.block_size]
  203. block_offset = i % self.block_size
  204. slot = block_number * self.block_size + block_offset
  205. slot_mapping[-1].append(slot)
  206. max_prompt_len = max(subquery_lens)
  207. input_tokens = _make_tensor_with_pad(input_tokens,
  208. max_prompt_len,
  209. pad=0,
  210. dtype=torch.long,
  211. device=self.device)
  212. input_positions = _make_tensor_with_pad(input_positions,
  213. max_prompt_len,
  214. pad=0,
  215. dtype=torch.long,
  216. device=self.device)
  217. slot_mapping = _make_tensor_with_pad(slot_mapping,
  218. max_prompt_len,
  219. pad=_PAD_SLOT_ID,
  220. dtype=torch.long,
  221. device=self.device)
  222. lora_index_mapping = [
  223. _pad_to_max(mapping, max_prompt_len, pad=0)
  224. for mapping in lora_index_mapping
  225. ]
  226. context_lens_tensor = torch.tensor(context_lens,
  227. dtype=torch.int,
  228. device=self.device)
  229. # Prepare prefix block tables
  230. max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
  231. block_tables = _make_tensor_with_pad(
  232. prefix_block_tables,
  233. max_len=max_prompt_block_table_len,
  234. pad=0,
  235. dtype=torch.int,
  236. device=self.device,
  237. )
  238. start_loc_tensor = torch.arange(0,
  239. len(prompt_lens) * max_prompt_len,
  240. max_prompt_len,
  241. dtype=torch.long,
  242. device=self.device)
  243. prompt_lens_tensor = torch.tensor(prompt_lens,
  244. dtype=torch.long,
  245. device=self.device)
  246. input_metadata = InputMetadata(
  247. is_prompt=True,
  248. slot_mapping=slot_mapping,
  249. prompt_lens=prompt_lens_tensor,
  250. max_seq_len=max_prompt_len,
  251. start_loc=start_loc_tensor,
  252. max_context_len=None,
  253. context_lens=context_lens_tensor,
  254. block_tables=block_tables,
  255. use_cuda_graph=False,
  256. kv_cache_dtype=self.kv_cache_dtype,
  257. kv_quant_params=self.kv_quant_params,
  258. )
  259. return (input_tokens, input_positions, input_metadata, prompt_lens,
  260. subquery_lens, lora_index_mapping, lora_prompt_mapping,
  261. lora_requests)
  262. def _prepare_decode(
  263. self,
  264. seq_group_metadata_list: List[SequenceGroupMetadata],
  265. ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
  266. Set[LoRARequest]]:
  267. assert len(seq_group_metadata_list) > 0
  268. input_tokens: List[List[int]] = []
  269. input_positions: List[List[int]] = []
  270. slot_mapping: List[List[int]] = []
  271. context_lens: List[int] = []
  272. block_tables: List[List[int]] = []
  273. lora_index_mapping: List[int] = []
  274. lora_prompt_mapping: List[int] = []
  275. lora_requests: Set[LoRARequest] = set()
  276. for seq_group_metadata in seq_group_metadata_list:
  277. assert not seq_group_metadata.is_prompt
  278. seq_ids = list(seq_group_metadata.seq_data.keys())
  279. lora_id = seq_group_metadata.lora_int_id
  280. if lora_id > 0:
  281. lora_requests.add(seq_group_metadata.lora_request)
  282. for seq_id in seq_ids:
  283. seq_data = seq_group_metadata.seq_data[seq_id]
  284. generation_token = seq_data.get_last_token_id()
  285. input_tokens.append([generation_token])
  286. seq_len = seq_data.get_len()
  287. position = seq_len - 1
  288. input_positions.append([position])
  289. context_len = seq_len if self.sliding_window is None else min(
  290. seq_len, self.sliding_window)
  291. context_lens.append(context_len)
  292. block_table = seq_group_metadata.block_tables[seq_id]
  293. block_number = block_table[position // self.block_size]
  294. block_offset = position % self.block_size
  295. slot = block_number * self.block_size + block_offset
  296. slot_mapping.append([slot])
  297. lora_index_mapping.append([lora_id])
  298. lora_prompt_mapping.append(lora_id)
  299. if self.sliding_window is not None:
  300. sliding_window_blocks = (self.sliding_window //
  301. self.block_size)
  302. block_table = block_table[-sliding_window_blocks:]
  303. block_tables.append(block_table)
  304. batch_size = len(input_tokens)
  305. max_context_len = max(context_lens)
  306. use_captured_graph = (
  307. not self.model_config.enforce_eager
  308. and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
  309. and max_context_len <= self.max_context_len_to_capture)
  310. if use_captured_graph:
  311. # Pad the input tokens, positions, and slot mapping to match the
  312. # batch size of the captured graph.
  313. graph_batch_size = _get_graph_batch_size(batch_size)
  314. assert graph_batch_size >= batch_size
  315. for _ in range(graph_batch_size - batch_size):
  316. input_tokens.append([])
  317. input_positions.append([])
  318. slot_mapping.append([])
  319. context_lens.append(1)
  320. block_tables.append([])
  321. batch_size = graph_batch_size
  322. input_tokens = _make_tensor_with_pad(input_tokens,
  323. max_len=1,
  324. pad=0,
  325. dtype=torch.long,
  326. device=self.device)
  327. input_positions = _make_tensor_with_pad(input_positions,
  328. max_len=1,
  329. pad=0,
  330. dtype=torch.long,
  331. device=self.device)
  332. slot_mapping = _make_tensor_with_pad(slot_mapping,
  333. max_len=1,
  334. pad=_PAD_SLOT_ID,
  335. dtype=torch.long,
  336. device=self.device)
  337. context_lens = torch.tensor(context_lens,
  338. dtype=torch.int,
  339. device=self.device)
  340. if use_captured_graph:
  341. # The shape of graph_block_tables is
  342. # [max batch size, max context len // block size].
  343. input_block_tables = self.graph_block_tables[:batch_size]
  344. for i, block_table in enumerate(block_tables):
  345. if block_table:
  346. input_block_tables[i, :len(block_table)] = block_table
  347. block_tables = torch.tensor(input_block_tables, device=self.device)
  348. else:
  349. max_block_table_len = max(
  350. len(block_table) for block_table in block_tables)
  351. block_tables = _make_tensor_with_pad(
  352. block_tables,
  353. max_len=max_block_table_len,
  354. pad=0,
  355. dtype=torch.int,
  356. device=self.device,
  357. )
  358. lora_index_mapping = [
  359. _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping
  360. ]
  361. input_metadata = InputMetadata(
  362. is_prompt=False,
  363. slot_mapping=slot_mapping,
  364. prompt_lens=None,
  365. max_seq_len=None,
  366. start_loc=None,
  367. max_context_len=max_context_len,
  368. context_lens=context_lens,
  369. block_tables=block_tables,
  370. use_cuda_graph=use_captured_graph,
  371. kv_cache_dtype=self.kv_cache_dtype,
  372. kv_quant_params=self.kv_quant_params,
  373. )
  374. return (input_tokens, input_positions, input_metadata,
  375. lora_index_mapping, lora_prompt_mapping, lora_requests)
  376. def _prepare_sample(
  377. self,
  378. seq_group_metadata_list: List[SequenceGroupMetadata],
  379. prompt_lens: List[int],
  380. subquery_lens: Optional[List[int]],
  381. ) -> SamplingMetadata:
  382. seq_groups: List[Tuple[List[int], SamplingParams]] = []
  383. selected_token_indices: List[int] = []
  384. generators: List[torch.Generator] = []
  385. selected_token_start_idx = 0
  386. categorized_sample_indices = {t: [] for t in SamplingType}
  387. categorized_sample_indices_start_idx = 0
  388. max_subquery_len = max(subquery_lens) if subquery_lens else 1
  389. for i, seq_group_metadata in enumerate(seq_group_metadata_list):
  390. seq_ids = list(seq_group_metadata.seq_data.keys())
  391. sampling_params = seq_group_metadata.sampling_params
  392. seq_groups.append((seq_ids, sampling_params))
  393. if seq_group_metadata.is_prompt:
  394. assert len(seq_ids) == 1
  395. assert subquery_lens is not None
  396. subquery_len = subquery_lens[i]
  397. if sampling_params.prompt_logprobs is not None:
  398. # NOTE: prompt token positions do not need sample, skip
  399. categorized_sample_indices_start_idx += subquery_len - 1
  400. categorized_sample_indices[
  401. sampling_params.sampling_type].append(
  402. categorized_sample_indices_start_idx)
  403. categorized_sample_indices_start_idx += 1
  404. if sampling_params.prompt_logprobs is not None:
  405. selected_token_indices.extend(
  406. range(selected_token_start_idx,
  407. selected_token_start_idx + subquery_len - 1))
  408. selected_token_indices.append(selected_token_start_idx +
  409. subquery_len - 1)
  410. selected_token_start_idx += max_subquery_len
  411. if sampling_params.seed is not None:
  412. seq_group_metadata.state.generator = torch.Generator(
  413. device="cuda").manual_seed(sampling_params.seed)
  414. else:
  415. num_seqs = len(seq_ids)
  416. selected_token_indices.extend(
  417. range(selected_token_start_idx,
  418. selected_token_start_idx + num_seqs))
  419. selected_token_start_idx += num_seqs
  420. categorized_sample_indices[
  421. sampling_params.sampling_type].extend(
  422. range(categorized_sample_indices_start_idx,
  423. categorized_sample_indices_start_idx + num_seqs))
  424. categorized_sample_indices_start_idx += num_seqs
  425. if sampling_params.seed is not None:
  426. generators.append(seq_group_metadata.state.generator)
  427. selected_token_indices = _async_h2d(selected_token_indices,
  428. dtype=torch.long,
  429. target_device=self.device,
  430. pin_memory=not self.in_wsl)
  431. categorized_sample_indices = {
  432. t: _async_h2d(seq_ids,
  433. dtype=torch.int,
  434. target_device=self.device,
  435. pin_memory=not self.in_wsl)
  436. for t, seq_ids in categorized_sample_indices.items()
  437. }
  438. seq_data: Dict[int, SequenceData] = {}
  439. for seq_group_metadata in seq_group_metadata_list:
  440. seq_data.update(seq_group_metadata.seq_data)
  441. seq_persistence_data: Dict[int, dict] = {}
  442. for grp in seq_group_metadata_list:
  443. seq_persistence_data.update(grp.persistent_data)
  444. sampling_metadata = SamplingMetadata(
  445. seq_groups=seq_groups,
  446. seq_data=seq_data,
  447. prompt_lens=prompt_lens,
  448. selected_token_indices=selected_token_indices,
  449. categorized_sample_indices=categorized_sample_indices,
  450. generators=generators,
  451. persistent_metadata=PersistentMetadata(seq_persistence_data),
  452. )
  453. return sampling_metadata
  454. def prepare_input_tensors(
  455. self,
  456. seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
  457. ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata,
  458. Set[int], LoRAMapping]:
  459. if self.is_driver_worker:
  460. # NOTE: We assume that all sequences in the group are all prompts or
  461. # all decodes.
  462. is_prompt = seq_group_metadata_list[0].is_prompt
  463. # Prepare input tensors.
  464. if is_prompt:
  465. (input_tokens, input_positions, input_metadata, prompt_lens,
  466. subquery_lens, lora_index_mapping, lora_prompt_mapping,
  467. lora_requests) = self._prepare_prompt(seq_group_metadata_list)
  468. else:
  469. (input_tokens, input_positions, input_metadata,
  470. lora_index_mapping, lora_prompt_mapping,
  471. lora_requests) = self._prepare_decode(seq_group_metadata_list)
  472. prompt_lens = []
  473. subquery_lens = None
  474. sampling_metadata = self._prepare_sample(seq_group_metadata_list,
  475. prompt_lens,
  476. subquery_lens)
  477. if self.lora_config:
  478. flat_lora_index_mapping = [
  479. item for sublist in lora_index_mapping for item in sublist
  480. ]
  481. lora_mapping = LoRAMapping(
  482. flat_lora_index_mapping,
  483. lora_prompt_mapping,
  484. )
  485. else:
  486. lora_mapping = None
  487. # Broadcast the metadata.
  488. metadata_dict = {
  489. "input_tokens": input_tokens,
  490. "input_positions": input_positions,
  491. "is_prompt": input_metadata.is_prompt,
  492. "slot_mapping": input_metadata.slot_mapping,
  493. "prompt_lens": input_metadata.prompt_lens,
  494. "max_seq_len": input_metadata.max_seq_len,
  495. "start_loc": input_metadata.start_loc,
  496. "max_context_len": input_metadata.max_context_len,
  497. "context_lens": input_metadata.context_lens,
  498. "block_tables": input_metadata.block_tables,
  499. "use_cuda_graph": input_metadata.use_cuda_graph,
  500. "kv_cache_dtype": input_metadata.kv_cache_dtype,
  501. "kv_quant_params": input_metadata.kv_quant_params,
  502. "selected_token_indices":
  503. sampling_metadata.selected_token_indices,
  504. "lora_requests": lora_requests,
  505. "lora_mapping": lora_mapping,
  506. }
  507. broadcast_tensor_dict(metadata_dict, src=0)
  508. else:
  509. metadata_dict = broadcast_tensor_dict(src=0)
  510. input_tokens = metadata_dict["input_tokens"]
  511. input_positions = metadata_dict["input_positions"]
  512. lora_mapping = metadata_dict["lora_mapping"]
  513. lora_requests = metadata_dict["lora_requests"]
  514. input_metadata = InputMetadata(
  515. is_prompt=metadata_dict["is_prompt"],
  516. slot_mapping=metadata_dict["slot_mapping"],
  517. prompt_lens=metadata_dict["prompt_lens"],
  518. max_seq_len=metadata_dict["max_seq_len"],
  519. start_loc=metadata_dict["start_loc"],
  520. max_context_len=metadata_dict["max_context_len"],
  521. context_lens=metadata_dict["context_lens"],
  522. block_tables=metadata_dict["block_tables"],
  523. use_cuda_graph=metadata_dict["use_cuda_graph"],
  524. kv_cache_dtype=metadata_dict["kv_cache_dtype"],
  525. kv_quant_params=metadata_dict["kv_quant_params"],
  526. )
  527. sampling_metadata = SamplingMetadata(
  528. seq_groups=None,
  529. seq_data=None,
  530. prompt_lens=None,
  531. selected_token_indices=metadata_dict["selected_token_indices"],
  532. categorized_sample_indices=None,
  533. generators=None,
  534. perform_sampling=False,
  535. )
  536. return (input_tokens, input_positions, input_metadata,
  537. sampling_metadata, lora_requests, lora_mapping)
  538. @torch.inference_mode()
  539. def execute_model(
  540. self,
  541. seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
  542. kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
  543. ) -> Optional[SamplerOutput]:
  544. (input_tokens, input_positions, input_metadata, sampling_metadata,
  545. lora_requests,
  546. lora_mapping) = (self.prepare_input_tensors(seq_group_metadata_list))
  547. if self.lora_config:
  548. self.set_active_loras(lora_requests, lora_mapping)
  549. # Execute the model.
  550. if input_metadata.use_cuda_graph:
  551. graph_batch_size = input_tokens.shape[0]
  552. model_executable = self.graph_runners[graph_batch_size]
  553. else:
  554. model_executable = self.model
  555. hidden_states = model_executable(
  556. input_ids=input_tokens,
  557. positions=input_positions,
  558. kv_caches=kv_caches,
  559. input_metadata=input_metadata,
  560. )
  561. # Sample the next token.
  562. output = self.model.sample(
  563. hidden_states=hidden_states,
  564. sampling_metadata=sampling_metadata,
  565. )
  566. return output
  567. @torch.inference_mode()
  568. def profile_run(self) -> None:
  569. # Enable top-k sampling to reflect the accurate memory usage.
  570. vocab_size = self.model_config.get_vocab_size()
  571. sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
  572. max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
  573. max_num_seqs = self.scheduler_config.max_num_seqs
  574. # This represents the maximum number of different requests
  575. # that will have unique loras, an therefore the max amount of memory
  576. # consumption create dummy lora request copies from the lora request
  577. # passed in, which contains a lora from the lora warmup path.
  578. dummy_lora_requests = []
  579. dummy_lora_requests_per_seq = []
  580. if self.lora_config:
  581. for idx in range(self.lora_config.max_loras):
  582. lora_id = idx + 1
  583. dummy_lora_request = LoRARequest(
  584. lora_name=f"warmup_{lora_id}",
  585. lora_int_id=lora_id,
  586. lora_local_path="/not/a/real/path",
  587. )
  588. self.lora_manager.add_dummy_lora(dummy_lora_request,
  589. rank=LORA_WARMUP_RANK)
  590. dummy_lora_requests.append(dummy_lora_request)
  591. dummy_lora_requests_per_seq = [
  592. dummy_lora_requests[idx % len(dummy_lora_requests)]
  593. for idx in range(max_num_seqs)
  594. ]
  595. # Profile memory usage with max_num_sequences sequences and the total
  596. # number of tokens equal to max_num_batched_tokens.
  597. seqs: List[SequenceGroupMetadata] = []
  598. for group_id in range(max_num_seqs):
  599. seq_len = (max_num_batched_tokens // max_num_seqs +
  600. (group_id < max_num_batched_tokens % max_num_seqs))
  601. seq_data = SequenceData([0] * seq_len)
  602. seq = SequenceGroupMetadata(
  603. request_id=str(group_id),
  604. is_prompt=True,
  605. seq_data={group_id: seq_data},
  606. sampling_params=sampling_params,
  607. block_tables=None,
  608. persistent_data={},
  609. lora_request=dummy_lora_requests_per_seq[group_id]
  610. if dummy_lora_requests_per_seq else None,
  611. )
  612. seqs.append(seq)
  613. # Run the model with the dummy inputs.
  614. num_layers = self.model_config.get_num_layers(self.parallel_config)
  615. kv_caches = [(None, None)] * num_layers
  616. self.execute_model(seqs, kv_caches)
  617. torch.cuda.synchronize()
  618. return
  619. def remove_all_loras(self) -> bool:
  620. if not self.lora_manager:
  621. raise RuntimeError("LoRA is not enabled.")
  622. return self.lora_manager.remove_all_loras()
  623. def set_active_loras(self, lora_requests: List[LoRARequest],
  624. lora_mapping: LoRAMapping) -> None:
  625. if not self.lora_manager:
  626. raise RuntimeError("LoRA is not enabled.")
  627. self.lora_manager.set_active_loras(lora_requests, lora_mapping)
  628. def add_lora(self, lora_request: LoRARequest) -> bool:
  629. if not self.lora_manager:
  630. raise RuntimeError("LoRA is not enabled.")
  631. return self.lora_manager.add_lora(lora_request)
  632. def remove_lora(self, lora_id: int) -> bool:
  633. if not self.lora_manager:
  634. raise RuntimeError("LoRA is not enabled.")
  635. return self.lora_manager.remove_lora(lora_id)
  636. def list_loras(self) -> Set[int]:
  637. if not self.lora_manager:
  638. raise RuntimeError("LoRA is not enabled.")
  639. return self.lora_manager.list_loras()
  640. @torch.inference_mode()
  641. def capture_model(self, kv_caches: List[KVCache]) -> None:
  642. # NOTE: This is a hack to ensure that the NCCL backend is never
  643. # deleted before the CUDA graph
  644. self.cupy_nccl_backend = cupy_utils.get_nccl_backend()
  645. assert not self.model_config.enforce_eager
  646. logger.info("Capturing the model for CUDA graphs. This may lead to "
  647. "unexpected consequences if the model is not static. To "
  648. "run the model in eager mode, set 'enforce_eager=True' or "
  649. "use '--enforce-eager' in the CLI.")
  650. logger.warning("CUDA graphs can take additional 1~3 GiB of memory "
  651. "per GPU. If you are running out of memory, consider "
  652. "decreasing `gpu_memory_utilization` or enforcing "
  653. "eager mode.")
  654. start_time = time.perf_counter()
  655. # Prepare dummy inputs. These will be reused for all batch sizes.
  656. max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
  657. input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
  658. input_positions = torch.zeros(max_batch_size, 1,
  659. dtype=torch.long).cuda()
  660. slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda()
  661. slot_mapping.fill_(_PAD_SLOT_ID)
  662. context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
  663. block_tables = torch.from_numpy(self.graph_block_tables).cuda()
  664. graph_batch_size = _get_graph_batch_size(
  665. self.scheduler_config.max_num_seqs)
  666. batch_size_capture_list = [
  667. bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
  668. ]
  669. # NOTE: Capturing the largest batch size first may help reduce the
  670. # memory usage of CUDA graph.
  671. # NOTE: There are 3 backends for all-reduce: custom all-reduce
  672. # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
  673. # either custom all-reduce kernel or CuPy NCCL. When not using CUDA
  674. # graph, we use either custom all-reduce kernel or PyTorch NCCL.
  675. # We always prioritize using custom all-reduce kernel but fall back
  676. # to PyTorch or CuPy NCCL if it is disabled or not supported.
  677. # Initialize a new progress bar
  678. progress = get_loading_progress_bar()
  679. task = progress.add_task("[cyan]Capturing graph...",
  680. total=len(batch_size_capture_list))
  681. with progress:
  682. with custom_all_reduce.capture():
  683. for batch_size in reversed(batch_size_capture_list):
  684. if batch_size > self.scheduler_config.max_num_seqs:
  685. continue
  686. # Create dummy input_metadata.
  687. input_metadata = InputMetadata(
  688. is_prompt=False,
  689. slot_mapping=slot_mapping[:batch_size],
  690. prompt_lens=None,
  691. max_seq_len=None,
  692. start_loc=None,
  693. max_context_len=self.max_context_len_to_capture,
  694. context_lens=context_lens[:batch_size],
  695. block_tables=block_tables[:batch_size],
  696. use_cuda_graph=True,
  697. kv_cache_dtype=self.kv_cache_dtype,
  698. kv_quant_params=self.kv_quant_params,
  699. )
  700. if self.lora_config:
  701. lora_mapping = LoRAMapping(
  702. [0] * batch_size,
  703. [0] * batch_size,
  704. )
  705. self.set_active_loras(set(), lora_mapping)
  706. graph_runner = CUDAGraphRunner(self.model)
  707. graph_runner.capture(
  708. input_tokens[:batch_size],
  709. input_positions[:batch_size],
  710. kv_caches,
  711. input_metadata,
  712. memory_pool=self.graph_memory_pool,
  713. )
  714. self.graph_memory_pool = graph_runner.graph.pool()
  715. self.graph_runners[batch_size] = graph_runner
  716. # Update the progress bar
  717. progress.update(task, advance=1)
  718. end_time = time.perf_counter()
  719. elapsed_time = end_time - start_time
  720. # This usually takes < 10 seconds.
  721. logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
  722. def __del__(self) -> None:
  723. # Delete the CUDA graphs before deleting the CuPy NCCL communicator.
  724. # NOTE: This is necessary because otherwise deadlocks can
  725. # happen.
  726. # FIXME: This is a bit hacky. Find a more robust solution.
  727. self.graph_runners.clear()
  728. self.cupy_nccl_backend = None
class CUDAGraphRunner:
    """Record a model forward pass as a CUDA graph and replay it.

    ``capture`` runs the model once eagerly, then records one forward pass on
    fixed input tensors. ``forward`` (also reached through ``__call__``)
    copies new inputs into those tensors, replays the graph and returns the
    output tensor recorded at capture time.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        # Set by capture(); stays None until a graph has been recorded.
        self.graph = None
        # Capture-time input tensors; forward() copies new data into them.
        # NOTE(review): also holds the kv_caches list, not only tensors.
        self.input_buffers: Dict[str, torch.Tensor] = {}
        # Capture-time outputs; each replay writes into these tensors.
        self.output_buffers: Dict[str, torch.Tensor] = {}

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        memory_pool,
    ) -> None:
        """Record the model's forward pass over the given tensors.

        The tensors passed here become the graph's fixed input buffers. These
        are ``input_ids``, ``positions`` and the ``slot_mapping``,
        ``context_lens`` and ``block_tables`` tensors of ``input_metadata``.

        Args:
            input_ids: Token ids fed to the model.
            positions: Position ids fed to the model.
            kv_caches: KV cache tensors. forward() does not copy these,
                because the same tensors are reused.
            input_metadata: Attention metadata for the captured batch.
            memory_pool: Passed to ``torch.cuda.graph(pool=...)``. Presumably
                a pool shared with previously captured graphs, or None for
                the first capture; confirm with the caller.

        Raises:
            AssertionError: if this runner has already captured a graph.
        """
        assert self.graph is None
        # Run the model once without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        with _maybe_cupy_nccl():
            self.model(
                input_ids,
                positions,
                kv_caches,
                input_metadata,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        # NOTE: Python 3.8 does not support multi-line with statements.
        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=memory_pool):
            with _maybe_cupy_nccl():
                hidden_states = self.model(
                    input_ids,
                    positions,
                    kv_caches,
                    input_metadata,
                )
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": input_metadata.slot_mapping,
            "context_lens": input_metadata.context_lens,
            "block_tables": input_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}
        return

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Replay the captured graph on new inputs.

        New inputs are copied (``non_blocking=True``) into the capture-time
        buffers. ``copy_`` needs shapes compatible with those buffers, so the
        inputs presumably must already be padded to the captured batch size;
        confirm with the caller. capture() must have run first. Otherwise
        ``input_buffers`` is empty and the lookups below raise KeyError.

        Returns:
            The hidden-states tensor recorded during capture. It is the same
            tensor object on every call, and the next replay overwrites it.
            Consume or clone it before replaying again.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping,
                                                 non_blocking=True)
        self.input_buffers["context_lens"].copy_(input_metadata.context_lens,
                                                 non_blocking=True)
        self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
                                                 non_blocking=True)

        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        # Alias for forward() so the runner can be called like a module.
        return self.forward(*args, **kwargs)
  803. @contextlib.contextmanager
  804. def _maybe_cupy_nccl():
  805. if cupy_utils.is_initialized() and not custom_all_reduce.is_initialized():
  806. with with_cupy_nccl_for_all_reduce():
  807. yield
  808. else:
  809. yield
  810. def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
  811. assert len(x) <= max_len
  812. return x + [pad] * (max_len - len(x))
  813. def _make_tensor_with_pad(
  814. x: List[List[int]],
  815. max_len: int,
  816. pad: int,
  817. dtype: torch.dtype,
  818. device: Optional[Union[str, torch.device]],
  819. ) -> torch.Tensor:
  820. padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
  821. return torch.tensor(padded_x, dtype=dtype, device=device)
  822. def _get_graph_batch_size(batch_size: int) -> int:
  823. if batch_size <= 2:
  824. return batch_size
  825. elif batch_size <= 4:
  826. return 4
  827. else:
  828. return (batch_size + 7) // 8 * 8
  829. def _async_h2d(
  830. data: list,
  831. dtype: torch.dtype,
  832. target_device: Union[str, torch.device],
  833. pin_memory: bool,
  834. ) -> torch.Tensor:
  835. t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
  836. return t.to(device=target_device, non_blocking=True)