model_runner.py 38 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882
  1. import contextlib
  2. import time
  3. from typing import Dict, List, Optional, Tuple, Set, Union
  4. import numpy as np
  5. import torch
  6. import torch.nn as nn
  7. from aphrodite.common.config import (DeviceConfig, ModelConfig, LoRAConfig,
  8. ParallelConfig, SchedulerConfig)
  9. from aphrodite.common.logger import init_logger
  10. from aphrodite.modeling import get_model, InputMetadata, SamplingMetadata
  11. from aphrodite.modeling.megatron import cupy_utils
  12. from aphrodite.modeling.megatron.communication_op import (broadcast_tensor_dict
  13. )
  14. from aphrodite.modeling.megatron.parallel_state import (
  15. with_cupy_nccl_for_all_reduce)
  16. from aphrodite.modeling.megatron import custom_all_reduce
  17. from aphrodite.common.sampling_params import SamplingParams, SamplingType
  18. from aphrodite.common.sequence import (SamplerOutput, SequenceData,
  19. SequenceGroupMetadata)
  20. from aphrodite.modeling.sampling_metadata import PersistentMetadata
  21. from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
  22. from aphrodite.lora.layers import LoRAMapping
  23. from aphrodite.lora.request import LoRARequest
  24. from aphrodite.common.utils import in_wsl
  25. logger = init_logger(__name__)
  26. KVCache = Tuple[torch.Tensor, torch.Tensor]
  27. _PAD_SLOT_ID = -1
  28. LORA_WARMUP_RANK = 8
  29. # Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
  30. # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
  31. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
  32. class ModelRunner:
  33. def __init__(
  34. self,
  35. model_config: ModelConfig,
  36. parallel_config: ParallelConfig,
  37. scheduler_config: SchedulerConfig,
  38. device_config: DeviceConfig,
  39. lora_config: Optional[LoRAConfig],
  40. kv_cache_dtype: Optional[str] = "auto",
  41. is_driver_worker: bool = False,
  42. ):
  43. self.model_config = model_config
  44. self.parallel_config = parallel_config
  45. self.scheduler_config = scheduler_config
  46. self.lora_config = lora_config
  47. self.is_driver_worker = is_driver_worker
  48. # model_config can be None in tests/samplers/test_sampler.py.
  49. # FIXME: This is a hack to make the tests work. Refactor this.
  50. self.sliding_window = (model_config.get_sliding_window()
  51. if model_config is not None else None)
  52. self.device_config = (device_config
  53. if device_config is not None else DeviceConfig())
  54. self.device = self.device_config.device
  55. self.model = None
  56. self.block_size = None # Set after initial profiling.
  57. self.lora_manager = None
  58. self.graph_runners: Dict[int, CUDAGraphRunner] = {}
  59. self.graph_memory_pool = None # Set during graph capture.
  60. self.max_context_len_to_capture = (
  61. self.model_config.max_context_len_to_capture
  62. if self.model_config is not None else 0)
  63. # When using CUDA graph, the input block tables must be padded to
  64. # max_context_len_to_capture. However, creating the block table in
  65. # Python can be expensive. To optimize this, we cache the block table
  66. # in numpy and only copy the actual input content at every iteration.
  67. # The shape of the cached block table will be
  68. # (max batch size to capture, max context len to capture / block size).
  69. self.graph_block_tables = None # Set after initial profiling.
  70. # cache in_wsl result
  71. self.in_wsl = in_wsl()
  72. self.kv_cache_dtype = kv_cache_dtype
  73. def load_model(self) -> None:
  74. self.model = get_model(self.model_config, self.device_config,
  75. self.lora_config)
  76. vocab_size = self.model.config.vocab_size
  77. if self.lora_config:
  78. assert hasattr(
  79. self.model, "supported_lora_modules"
  80. ) and self.model.supported_lora_modules, "Model does not support LoRA"
  81. assert hasattr(
  82. self.model,
  83. "embedding_modules"), "Model does not have embedding_modules"
  84. assert hasattr(self.model, "embedding_padding_modules"
  85. ), "Model does not have embedding_padding_modules"
  86. self.lora_manager = LRUCacheWorkerLoRAManager(
  87. self.scheduler_config.max_num_seqs,
  88. self.scheduler_config.max_num_batched_tokens +
  89. self.scheduler_config.max_paddings, vocab_size,
  90. self.lora_config, self.device, self.model.embedding_modules,
  91. self.model.embedding_padding_modules)
  92. self.model = self.lora_manager.create_lora_manager(self.model)
  93. def set_block_size(self, block_size: int) -> None:
  94. self.block_size = block_size
  95. max_num_blocks = (self.max_context_len_to_capture + block_size -
  96. 1) // block_size
  97. self.graph_block_tables = np.zeros(
  98. (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
  99. def _prepare_prompt(
  100. self,
  101. seq_group_metadata_list: List[SequenceGroupMetadata],
  102. ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
  103. List[int], List[int], Set[LoRARequest]]:
  104. assert len(seq_group_metadata_list) > 0
  105. input_tokens: List[List[int]] = []
  106. input_positions: List[List[int]] = []
  107. slot_mapping: List[List[int]] = []
  108. lora_index_mapping: List[int] = []
  109. lora_prompt_mapping: List[int] = []
  110. lora_requests: Set[LoRARequest] = set()
  111. prompt_lens: List[int] = []
  112. context_lens: List[int] = []
  113. subquery_lens: List[int] = []
  114. prefix_block_tables: List[List[int]] = []
  115. for seq_group_metadata in seq_group_metadata_list:
  116. assert seq_group_metadata.is_prompt
  117. seq_ids = list(seq_group_metadata.seq_data.keys())
  118. assert len(seq_ids) == 1
  119. seq_id = seq_ids[0]
  120. seq_data = seq_group_metadata.seq_data[seq_id]
  121. prompt_tokens = seq_data.get_token_ids()
  122. prompt_len = len(prompt_tokens)
  123. prompt_lens.append(prompt_len)
  124. prefix_len = 0
  125. prefix = seq_group_metadata.prefix
  126. if prefix is not None and prefix.computed:
  127. prefix_len = prefix.get_length()
  128. prompt_tokens = prompt_tokens[prefix_len:]
  129. prefix_block_tables.append(prefix.get_block_numbers())
  130. else:
  131. prefix_block_tables.append([])
  132. # actual prompt lens
  133. context_lens.append(prefix_len)
  134. subquery_lens.append(prompt_len - prefix_len)
  135. input_tokens.append(prompt_tokens)
  136. # NOTE: Here we assume that the first token in the prompt
  137. # is always the first token in the sequence.
  138. input_positions.append(
  139. list(range(prefix_len, prefix_len + len(prompt_tokens))))
  140. lora_id = seq_group_metadata.lora_int_id
  141. if lora_id > 0:
  142. lora_requests.add(seq_group_metadata.lora_request)
  143. lora_index_mapping.append([lora_id] * (prompt_len - prefix_len))
  144. lora_prompt_mapping.extend(
  145. [lora_id] *
  146. (prompt_len - prefix_len
  147. if seq_group_metadata.sampling_params.prompt_logprobs else 1))
  148. if seq_group_metadata.block_tables is None:
  149. # During memory profiling, the block tables are not initialized
  150. # yet. In this case, we just use a dummy slot mapping.
  151. slot_mapping.append([_PAD_SLOT_ID] * prompt_len)
  152. continue
  153. # Compute the slot mapping.
  154. slot_mapping.append([])
  155. block_table = seq_group_metadata.block_tables[seq_id]
  156. # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
  157. # where start_idx is max(0, prompt_len - sliding_window).
  158. # For example, if the prompt len is 10, sliding window is 8, and
  159. # block size is 4, the first two tokens are masked and the slot
  160. # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
  161. start_idx = 0
  162. if self.sliding_window is not None:
  163. assert prefix_len == 0, (
  164. "Prefix caching is currently not supported with "
  165. "sliding window attention")
  166. start_idx = max(0, prompt_len - self.sliding_window)
  167. for i in range(prefix_len, prompt_len):
  168. if i < start_idx:
  169. slot_mapping[-1].append(_PAD_SLOT_ID)
  170. continue
  171. block_number = block_table[i // self.block_size]
  172. block_offset = i % self.block_size
  173. slot = block_number * self.block_size + block_offset
  174. slot_mapping[-1].append(slot)
  175. max_prompt_len = max(subquery_lens)
  176. input_tokens = _make_tensor_with_pad(input_tokens,
  177. max_prompt_len,
  178. pad=0,
  179. dtype=torch.long,
  180. device=self.device)
  181. input_positions = _make_tensor_with_pad(input_positions,
  182. max_prompt_len,
  183. pad=0,
  184. dtype=torch.long,
  185. device=self.device)
  186. slot_mapping = _make_tensor_with_pad(slot_mapping,
  187. max_prompt_len,
  188. pad=_PAD_SLOT_ID,
  189. dtype=torch.long,
  190. device=self.device)
  191. lora_index_mapping = [
  192. _pad_to_max(mapping, max_prompt_len, pad=0)
  193. for mapping in lora_index_mapping
  194. ]
  195. context_lens_tensor = torch.tensor(context_lens,
  196. dtype=torch.int,
  197. device=self.device)
  198. # Prepare prefix block tables
  199. max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
  200. block_tables = _make_tensor_with_pad(
  201. prefix_block_tables,
  202. max_len=max_prompt_block_table_len,
  203. pad=0,
  204. dtype=torch.int,
  205. device=self.device,
  206. )
  207. start_loc_tensor = torch.arange(0,
  208. len(prompt_lens) * max_prompt_len,
  209. max_prompt_len,
  210. dtype=torch.long,
  211. device=self.device)
  212. prompt_lens_tensor = torch.tensor(prompt_lens,
  213. dtype=torch.long,
  214. device=self.device)
  215. input_metadata = InputMetadata(
  216. is_prompt=True,
  217. slot_mapping=slot_mapping,
  218. prompt_lens=prompt_lens_tensor,
  219. max_seq_len=max_prompt_len,
  220. start_loc=start_loc_tensor,
  221. max_context_len=None,
  222. context_lens=context_lens_tensor,
  223. block_tables=block_tables,
  224. use_cuda_graph=False,
  225. kv_cache_dtype=self.kv_cache_dtype,
  226. )
  227. return (input_tokens, input_positions, input_metadata, prompt_lens,
  228. subquery_lens, lora_index_mapping, lora_prompt_mapping,
  229. lora_requests)
  230. def _prepare_decode(
  231. self,
  232. seq_group_metadata_list: List[SequenceGroupMetadata],
  233. ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
  234. Set[LoRARequest]]:
  235. assert len(seq_group_metadata_list) > 0
  236. input_tokens: List[List[int]] = []
  237. input_positions: List[List[int]] = []
  238. slot_mapping: List[List[int]] = []
  239. context_lens: List[int] = []
  240. block_tables: List[List[int]] = []
  241. lora_index_mapping: List[int] = []
  242. lora_prompt_mapping: List[int] = []
  243. lora_requests: Set[LoRARequest] = set()
  244. for seq_group_metadata in seq_group_metadata_list:
  245. assert not seq_group_metadata.is_prompt
  246. seq_ids = list(seq_group_metadata.seq_data.keys())
  247. lora_id = seq_group_metadata.lora_int_id
  248. if lora_id > 0:
  249. lora_requests.add(seq_group_metadata.lora_request)
  250. for seq_id in seq_ids:
  251. seq_data = seq_group_metadata.seq_data[seq_id]
  252. generation_token = seq_data.get_last_token_id()
  253. input_tokens.append([generation_token])
  254. seq_len = seq_data.get_len()
  255. position = seq_len - 1
  256. input_positions.append([position])
  257. context_len = seq_len if self.sliding_window is None else min(
  258. seq_len, self.sliding_window)
  259. context_lens.append(context_len)
  260. block_table = seq_group_metadata.block_tables[seq_id]
  261. block_number = block_table[position // self.block_size]
  262. block_offset = position % self.block_size
  263. slot = block_number * self.block_size + block_offset
  264. slot_mapping.append([slot])
  265. lora_index_mapping.append([lora_id])
  266. lora_prompt_mapping.append(lora_id)
  267. if self.sliding_window is not None:
  268. sliding_window_blocks = (self.sliding_window //
  269. self.block_size)
  270. block_table = block_table[-sliding_window_blocks:]
  271. block_tables.append(block_table)
  272. batch_size = len(input_tokens)
  273. max_context_len = max(context_lens)
  274. use_captured_graph = (
  275. not self.model_config.enforce_eager
  276. and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
  277. and max_context_len <= self.max_context_len_to_capture)
  278. if use_captured_graph:
  279. # Pad the input tokens, positions, and slot mapping to match the
  280. # batch size of the captured graph.
  281. graph_batch_size = _get_graph_batch_size(batch_size)
  282. assert graph_batch_size >= batch_size
  283. for _ in range(graph_batch_size - batch_size):
  284. input_tokens.append([])
  285. input_positions.append([])
  286. slot_mapping.append([])
  287. context_lens.append(1)
  288. block_tables.append([])
  289. batch_size = graph_batch_size
  290. input_tokens = _make_tensor_with_pad(input_tokens,
  291. max_len=1,
  292. pad=0,
  293. dtype=torch.long,
  294. device=self.device)
  295. input_positions = _make_tensor_with_pad(input_positions,
  296. max_len=1,
  297. pad=0,
  298. dtype=torch.long,
  299. device=self.device)
  300. slot_mapping = _make_tensor_with_pad(slot_mapping,
  301. max_len=1,
  302. pad=_PAD_SLOT_ID,
  303. dtype=torch.long,
  304. device=self.device)
  305. context_lens = torch.tensor(context_lens,
  306. dtype=torch.int,
  307. device=self.device)
  308. if use_captured_graph:
  309. # The shape of graph_block_tables is
  310. # [max batch size, max context len // block size].
  311. input_block_tables = self.graph_block_tables[:batch_size]
  312. for i, block_table in enumerate(block_tables):
  313. if block_table:
  314. input_block_tables[i, :len(block_table)] = block_table
  315. block_tables = torch.tensor(input_block_tables, device=self.device)
  316. else:
  317. max_block_table_len = max(
  318. len(block_table) for block_table in block_tables)
  319. block_tables = _make_tensor_with_pad(
  320. block_tables,
  321. max_len=max_block_table_len,
  322. pad=0,
  323. dtype=torch.int,
  324. device=self.device,
  325. )
  326. lora_index_mapping = [
  327. _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping
  328. ]
  329. input_metadata = InputMetadata(
  330. is_prompt=False,
  331. slot_mapping=slot_mapping,
  332. prompt_lens=None,
  333. max_seq_len=None,
  334. start_loc=None,
  335. max_context_len=max_context_len,
  336. context_lens=context_lens,
  337. block_tables=block_tables,
  338. use_cuda_graph=use_captured_graph,
  339. kv_cache_dtype=self.kv_cache_dtype,
  340. )
  341. return (input_tokens, input_positions, input_metadata,
  342. lora_index_mapping, lora_prompt_mapping, lora_requests)
  343. def _prepare_sample(
  344. self,
  345. seq_group_metadata_list: List[SequenceGroupMetadata],
  346. prompt_lens: List[int],
  347. subquery_lens: Optional[List[int]],
  348. ) -> SamplingMetadata:
  349. seq_groups: List[Tuple[List[int], SamplingParams]] = []
  350. selected_token_indices: List[int] = []
  351. selected_token_start_idx = 0
  352. categorized_sample_indices = {t: [] for t in SamplingType}
  353. categorized_sample_indices_start_idx = 0
  354. max_subquery_len = max(subquery_lens) if subquery_lens else 1
  355. for i, seq_group_metadata in enumerate(seq_group_metadata_list):
  356. seq_ids = list(seq_group_metadata.seq_data.keys())
  357. sampling_params = seq_group_metadata.sampling_params
  358. seq_groups.append((seq_ids, sampling_params))
  359. if seq_group_metadata.is_prompt:
  360. assert len(seq_ids) == 1
  361. assert subquery_lens is not None
  362. subquery_len = subquery_lens[i]
  363. if sampling_params.prompt_logprobs is not None:
  364. # NOTE: prompt token positions do not need sample, skip
  365. categorized_sample_indices_start_idx += subquery_len - 1
  366. categorized_sample_indices[
  367. sampling_params.sampling_type].append(
  368. categorized_sample_indices_start_idx)
  369. categorized_sample_indices_start_idx += 1
  370. if sampling_params.prompt_logprobs is not None:
  371. selected_token_indices.extend(
  372. range(selected_token_start_idx,
  373. selected_token_start_idx + subquery_len - 1))
  374. selected_token_indices.append(selected_token_start_idx +
  375. subquery_len - 1)
  376. selected_token_start_idx += max_subquery_len
  377. else:
  378. num_seqs = len(seq_ids)
  379. selected_token_indices.extend(
  380. range(selected_token_start_idx,
  381. selected_token_start_idx + num_seqs))
  382. selected_token_start_idx += num_seqs
  383. categorized_sample_indices[
  384. sampling_params.sampling_type].extend(
  385. range(categorized_sample_indices_start_idx,
  386. categorized_sample_indices_start_idx + num_seqs))
  387. categorized_sample_indices_start_idx += num_seqs
  388. selected_token_indices = _async_h2d(selected_token_indices,
  389. dtype=torch.long,
  390. target_device=self.device,
  391. pin_memory=not self.in_wsl)
  392. categorized_sample_indices = {
  393. t: _async_h2d(seq_ids,
  394. dtype=torch.int,
  395. target_device=self.device,
  396. pin_memory=not self.in_wsl)
  397. for t, seq_ids in categorized_sample_indices.items()
  398. }
  399. seq_data: Dict[int, SequenceData] = {}
  400. for seq_group_metadata in seq_group_metadata_list:
  401. seq_data.update(seq_group_metadata.seq_data)
  402. seq_persistence_data: Dict[int, dict] = {}
  403. for grp in seq_group_metadata_list:
  404. seq_persistence_data.update(grp.persistent_data)
  405. sampling_metadata = SamplingMetadata(
  406. seq_groups=seq_groups,
  407. seq_data=seq_data,
  408. prompt_lens=prompt_lens,
  409. selected_token_indices=selected_token_indices,
  410. categorized_sample_indices=categorized_sample_indices,
  411. persistent_metadata=PersistentMetadata(seq_persistence_data),
  412. )
  413. return sampling_metadata
  414. def prepare_input_tensors(
  415. self,
  416. seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
  417. ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata,
  418. Set[int], LoRAMapping]:
  419. if self.is_driver_worker:
  420. # NOTE: We assume that all sequences in the group are all prompts or
  421. # all decodes.
  422. is_prompt = seq_group_metadata_list[0].is_prompt
  423. # Prepare input tensors.
  424. if is_prompt:
  425. (input_tokens, input_positions, input_metadata, prompt_lens,
  426. subquery_lens, lora_index_mapping, lora_prompt_mapping,
  427. lora_requests) = self._prepare_prompt(seq_group_metadata_list)
  428. else:
  429. (input_tokens, input_positions, input_metadata,
  430. lora_index_mapping, lora_prompt_mapping,
  431. lora_requests) = self._prepare_decode(seq_group_metadata_list)
  432. prompt_lens = []
  433. subquery_lens = None
  434. sampling_metadata = self._prepare_sample(seq_group_metadata_list,
  435. prompt_lens,
  436. subquery_lens)
  437. if self.lora_config:
  438. flat_lora_index_mapping = [
  439. item for sublist in lora_index_mapping for item in sublist
  440. ]
  441. lora_mapping = LoRAMapping(
  442. flat_lora_index_mapping,
  443. lora_prompt_mapping,
  444. )
  445. else:
  446. lora_mapping = None
  447. # Broadcast the metadata.
  448. metadata_dict = {
  449. "input_tokens": input_tokens,
  450. "input_positions": input_positions,
  451. "is_prompt": input_metadata.is_prompt,
  452. "slot_mapping": input_metadata.slot_mapping,
  453. "prompt_lens": input_metadata.prompt_lens,
  454. "max_seq_len": input_metadata.max_seq_len,
  455. "start_loc": input_metadata.start_loc,
  456. "max_context_len": input_metadata.max_context_len,
  457. "context_lens": input_metadata.context_lens,
  458. "block_tables": input_metadata.block_tables,
  459. "use_cuda_graph": input_metadata.use_cuda_graph,
  460. "kv_cache_dtype": input_metadata.kv_cache_dtype,
  461. "selected_token_indices":
  462. sampling_metadata.selected_token_indices,
  463. "lora_requests": lora_requests,
  464. "lora_mapping": lora_mapping,
  465. }
  466. broadcast_tensor_dict(metadata_dict, src=0)
  467. else:
  468. metadata_dict = broadcast_tensor_dict(src=0)
  469. input_tokens = metadata_dict["input_tokens"]
  470. input_positions = metadata_dict["input_positions"]
  471. lora_mapping = metadata_dict["lora_mapping"]
  472. lora_requests = metadata_dict["lora_requests"]
  473. input_metadata = InputMetadata(
  474. is_prompt=metadata_dict["is_prompt"],
  475. slot_mapping=metadata_dict["slot_mapping"],
  476. prompt_lens=metadata_dict["prompt_lens"],
  477. max_seq_len=metadata_dict["max_seq_len"],
  478. start_loc=metadata_dict["start_loc"],
  479. max_context_len=metadata_dict["max_context_len"],
  480. context_lens=metadata_dict["context_lens"],
  481. block_tables=metadata_dict["block_tables"],
  482. use_cuda_graph=metadata_dict["use_cuda_graph"],
  483. kv_cache_dtype=metadata_dict["kv_cache_dtype"],
  484. )
  485. sampling_metadata = SamplingMetadata(
  486. seq_groups=None,
  487. seq_data=None,
  488. prompt_lens=None,
  489. selected_token_indices=metadata_dict["selected_token_indices"],
  490. categorized_sample_indices=None,
  491. perform_sampling=False,
  492. )
  493. return (input_tokens, input_positions, input_metadata,
  494. sampling_metadata, lora_requests, lora_mapping)
  495. @torch.inference_mode()
  496. def execute_model(
  497. self,
  498. seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
  499. kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
  500. ) -> Optional[SamplerOutput]:
  501. (input_tokens, input_positions, input_metadata, sampling_metadata,
  502. lora_requests,
  503. lora_mapping) = (self.prepare_input_tensors(seq_group_metadata_list))
  504. if self.lora_config:
  505. self.set_active_loras(lora_requests, lora_mapping)
  506. # Execute the model.
  507. if input_metadata.use_cuda_graph:
  508. graph_batch_size = input_tokens.shape[0]
  509. model_executable = self.graph_runners[graph_batch_size]
  510. else:
  511. model_executable = self.model
  512. hidden_states = model_executable(
  513. input_ids=input_tokens,
  514. positions=input_positions,
  515. kv_caches=kv_caches,
  516. input_metadata=input_metadata,
  517. )
  518. # Sample the next token.
  519. output = self.model.sample(
  520. hidden_states=hidden_states,
  521. sampling_metadata=sampling_metadata,
  522. )
  523. return output
  524. @torch.inference_mode()
  525. def profile_run(self) -> None:
  526. # Enable top-k sampling to reflect the accurate memory usage.
  527. vocab_size = self.model_config.get_vocab_size()
  528. sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
  529. max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
  530. max_num_seqs = self.scheduler_config.max_num_seqs
  531. # This represents the maximum number of different requests
  532. # that will have unique loras, an therefore the max amount of memory
  533. # consumption create dummy lora request copies from the lora request
  534. # passed in, which contains a lora from the lora warmup path.
  535. dummy_lora_requests = []
  536. dummy_lora_requests_per_seq = []
  537. if self.lora_config:
  538. for idx in range(self.lora_config.max_loras):
  539. lora_id = idx + 1
  540. dummy_lora_request = LoRARequest(
  541. lora_name=f"warmup_{lora_id}",
  542. lora_int_id=lora_id,
  543. lora_local_path="/not/a/real/path",
  544. )
  545. self.lora_manager.add_dummy_lora(dummy_lora_request,
  546. rank=LORA_WARMUP_RANK)
  547. dummy_lora_requests.append(dummy_lora_request)
  548. dummy_lora_requests_per_seq = [
  549. dummy_lora_requests[idx % len(dummy_lora_requests)]
  550. for idx in range(max_num_seqs)
  551. ]
  552. # Profile memory usage with max_num_sequences sequences and the total
  553. # number of tokens equal to max_num_batched_tokens.
  554. seqs: List[SequenceGroupMetadata] = []
  555. for group_id in range(max_num_seqs):
  556. seq_len = (max_num_batched_tokens // max_num_seqs +
  557. (group_id < max_num_batched_tokens % max_num_seqs))
  558. seq_data = SequenceData([0] * seq_len)
  559. seq = SequenceGroupMetadata(
  560. request_id=str(group_id),
  561. is_prompt=True,
  562. seq_data={group_id: seq_data},
  563. sampling_params=sampling_params,
  564. block_tables=None,
  565. persistent_data={},
  566. lora_request=dummy_lora_requests_per_seq[group_id]
  567. if dummy_lora_requests_per_seq else None,
  568. )
  569. seqs.append(seq)
  570. # Run the model with the dummy inputs.
  571. num_layers = self.model_config.get_num_layers(self.parallel_config)
  572. kv_caches = [(None, None)] * num_layers
  573. self.execute_model(seqs, kv_caches)
  574. torch.cuda.synchronize()
  575. return
  576. def remove_all_loras(self) -> bool:
  577. if not self.lora_manager:
  578. raise RuntimeError("LoRA is not enabled.")
  579. return self.lora_manager.remove_all_loras()
  580. def set_active_loras(self, lora_requests: List[LoRARequest],
  581. lora_mapping: LoRAMapping) -> None:
  582. if not self.lora_manager:
  583. raise RuntimeError("LoRA is not enabled.")
  584. self.lora_manager.set_active_loras(lora_requests, lora_mapping)
  585. def add_lora(self, lora_request: LoRARequest) -> bool:
  586. if not self.lora_manager:
  587. raise RuntimeError("LoRA is not enabled.")
  588. return self.lora_manager.add_lora(lora_request)
  589. def remove_lora(self, lora_id: int) -> bool:
  590. if not self.lora_manager:
  591. raise RuntimeError("LoRA is not enabled.")
  592. return self.lora_manager.remove_lora(lora_id)
  593. def list_loras(self) -> Set[int]:
  594. if not self.lora_manager:
  595. raise RuntimeError("LoRA is not enabled.")
  596. return self.lora_manager.list_loras()
  597. @torch.inference_mode()
  598. def capture_model(self, kv_caches: List[KVCache]) -> None:
  599. # NOTE: This is a hack to ensure that the NCCL backend is never
  600. # deleted before the CUDA graph
  601. self.cupy_nccl_backend = cupy_utils.get_nccl_backend()
  602. assert not self.model_config.enforce_eager
  603. logger.info("Capturing the model for CUDA graphs. This may lead to "
  604. "unexpected consequences if the model is not static. To "
  605. "run the model in eager mode, set 'enforce_eager=True' or "
  606. "use '--enforce-eager' in the CLI.")
  607. logger.warning("CUDA graphs can take additional 1~3 GiB of memory "
  608. "per GPU. If you are running out of memory, consider "
  609. "decreasing `gpu_memory_utilization` or enforcing "
  610. "eager mode.")
  611. start_time = time.perf_counter()
  612. # Prepare dummy inputs. These will be reused for all batch sizes.
  613. max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
  614. input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
  615. input_positions = torch.zeros(max_batch_size, 1,
  616. dtype=torch.long).cuda()
  617. slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda()
  618. slot_mapping.fill_(_PAD_SLOT_ID)
  619. context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
  620. block_tables = torch.from_numpy(self.graph_block_tables).cuda()
  621. graph_batch_size = _get_graph_batch_size(
  622. self.scheduler_config.max_num_seqs)
  623. batch_size_capture_list = [
  624. bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
  625. ]
  626. # NOTE: Capturing the largest batch size first may help reduce the
  627. # memory usage of CUDA graph.
  628. # NOTE: There are 3 backends for all-reduce: custom all-reduce
  629. # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
  630. # either custom all-reduce kernel or CuPy NCCL. When not using CUDA
  631. # graph, we use either custom all-reduce kernel or PyTorch NCCL.
  632. # We always prioritize using custom all-reduce kernel but fall back
  633. # to PyTorch or CuPy NCCL if it is disabled or not supported.
  634. with custom_all_reduce.capture():
  635. for batch_size in reversed(batch_size_capture_list):
  636. if batch_size > self.scheduler_config.max_num_seqs:
  637. continue
  638. # Create dummy input_metadata.
  639. input_metadata = InputMetadata(
  640. is_prompt=False,
  641. slot_mapping=slot_mapping[:batch_size],
  642. prompt_lens=None,
  643. max_seq_len=None,
  644. start_loc=None,
  645. max_context_len=self.max_context_len_to_capture,
  646. context_lens=context_lens[:batch_size],
  647. block_tables=block_tables[:batch_size],
  648. use_cuda_graph=True,
  649. kv_cache_dtype=self.kv_cache_dtype,
  650. )
  651. if self.lora_config:
  652. lora_mapping = LoRAMapping(
  653. [0] * batch_size,
  654. [0] * batch_size,
  655. )
  656. self.set_active_loras(set(), lora_mapping)
  657. graph_runner = CUDAGraphRunner(self.model)
  658. graph_runner.capture(
  659. input_tokens[:batch_size],
  660. input_positions[:batch_size],
  661. kv_caches,
  662. input_metadata,
  663. memory_pool=self.graph_memory_pool,
  664. )
  665. self.graph_memory_pool = graph_runner.graph.pool()
  666. self.graph_runners[batch_size] = graph_runner
  667. end_time = time.perf_counter()
  668. elapsed_time = end_time - start_time
  669. # This usually takes < 10 seconds.
  670. logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
  671. def __del__(self) -> None:
  672. # Delete the CUDA graphs before deleting the CuPy NCCL communicator.
  673. # NOTE: This is necessary because otherwise deadlocks can
  674. # happen.
  675. # FIXME: This is a bit hacky. Find a more robust solution.
  676. self.graph_runners.clear()
  677. self.cupy_nccl_backend = None
  678. class CUDAGraphRunner:
  679. def __init__(self, model: nn.Module):
  680. self.model = model
  681. self.graph = None
  682. self.input_buffers: Dict[str, torch.Tensor] = {}
  683. self.output_buffers: Dict[str, torch.Tensor] = {}
  684. def capture(
  685. self,
  686. input_ids: torch.Tensor,
  687. positions: torch.Tensor,
  688. kv_caches: List[KVCache],
  689. input_metadata: InputMetadata,
  690. memory_pool,
  691. ) -> None:
  692. assert self.graph is None
  693. # Run the model once without capturing the graph.
  694. # This is to make sure that the captured graph does not include the
  695. # kernel launches for initial benchmarking (e.g., Triton autotune).
  696. with _maybe_cupy_nccl():
  697. self.model(
  698. input_ids,
  699. positions,
  700. kv_caches,
  701. input_metadata,
  702. )
  703. torch.cuda.synchronize()
  704. # Capture the graph.
  705. # NOTE: Python 3.8 does not support multi-line with statements.
  706. # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
  707. self.graph = torch.cuda.CUDAGraph()
  708. with torch.cuda.graph(self.graph, pool=memory_pool):
  709. with _maybe_cupy_nccl():
  710. hidden_states = self.model(
  711. input_ids,
  712. positions,
  713. kv_caches,
  714. input_metadata,
  715. )
  716. torch.cuda.synchronize()
  717. # Save the input and output buffers.
  718. self.input_buffers = {
  719. "input_ids": input_ids,
  720. "positions": positions,
  721. "kv_caches": kv_caches,
  722. "slot_mapping": input_metadata.slot_mapping,
  723. "context_lens": input_metadata.context_lens,
  724. "block_tables": input_metadata.block_tables,
  725. }
  726. self.output_buffers = {"hidden_states": hidden_states}
  727. return
  728. def forward(
  729. self,
  730. input_ids: torch.Tensor,
  731. positions: torch.Tensor,
  732. kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
  733. input_metadata: InputMetadata,
  734. ) -> torch.Tensor:
  735. # KV caches are fixed tensors, so we don't need to copy them.
  736. del kv_caches
  737. # Copy the input tensors to the input buffers.
  738. self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
  739. self.input_buffers["positions"].copy_(positions, non_blocking=True)
  740. self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping,
  741. non_blocking=True)
  742. self.input_buffers["context_lens"].copy_(input_metadata.context_lens,
  743. non_blocking=True)
  744. self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
  745. non_blocking=True)
  746. # Run the graph.
  747. self.graph.replay()
  748. # Return the output tensor.
  749. return self.output_buffers["hidden_states"]
  750. def __call__(self, *args, **kwargs):
  751. return self.forward(*args, **kwargs)
  752. @contextlib.contextmanager
  753. def _maybe_cupy_nccl():
  754. if cupy_utils.is_initialized() and not custom_all_reduce.is_initialized():
  755. with with_cupy_nccl_for_all_reduce():
  756. yield
  757. else:
  758. yield
  759. def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
  760. assert len(x) <= max_len
  761. return x + [pad] * (max_len - len(x))
  762. def _make_tensor_with_pad(
  763. x: List[List[int]],
  764. max_len: int,
  765. pad: int,
  766. dtype: torch.dtype,
  767. device: Optional[Union[str, torch.device]],
  768. ) -> torch.Tensor:
  769. padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
  770. return torch.tensor(padded_x, dtype=dtype, device=device)
  771. def _get_graph_batch_size(batch_size: int) -> int:
  772. if batch_size <= 2:
  773. return batch_size
  774. elif batch_size <= 4:
  775. return 4
  776. else:
  777. return (batch_size + 7) // 8 * 8
  778. def _async_h2d(
  779. data: list,
  780. dtype: torch.dtype,
  781. target_device: Union[str, torch.device],
  782. pin_memory: bool,
  783. ) -> torch.Tensor:
  784. t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
  785. return t.to(device=target_device, non_blocking=True)