model_runner.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030
  1. import contextlib
  2. import time
  3. from typing import Dict, List, Optional, Tuple, Set, Union
  4. import numpy as np
  5. import torch
  6. import torch.nn as nn
  7. from loguru import logger
  8. from aphrodite.common.config import (
  9. DeviceConfig,
  10. ModelConfig,
  11. LoRAConfig,
  12. ParallelConfig,
  13. SchedulerConfig,
  14. )
  15. from aphrodite.common.logger import get_loading_progress_bar
  16. from aphrodite.modeling import get_model, InputMetadata, SamplingMetadata
  17. from aphrodite.modeling.megatron import cupy_utils
  18. from aphrodite.modeling.megatron.communication_op import broadcast_tensor_dict
  19. from aphrodite.modeling.megatron.parallel_state import (
  20. get_tensor_model_parallel_world_size,
  21. with_cupy_nccl_for_all_reduce,
  22. )
  23. from aphrodite.modeling.megatron import custom_all_reduce
  24. from aphrodite.common.sampling_params import SamplingParams, SamplingType
  25. from aphrodite.common.sequence import (
  26. SamplerOutput,
  27. SequenceData,
  28. SequenceGroupMetadata,
  29. )
  30. from aphrodite.modeling.sampling_metadata import PersistentMetadata
  31. from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
  32. from aphrodite.lora.layers import LoRAMapping
  33. from aphrodite.lora.request import LoRARequest
  34. from aphrodite.common.utils import in_wsl, measure_cuda_memory
  35. KVCache = Tuple[torch.Tensor, torch.Tensor]
  36. _PAD_SLOT_ID = -1
  37. LORA_WARMUP_RANK = 8
  38. # Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
  39. # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
  40. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
  41. class ModelRunner:
  42. def __init__(
  43. self,
  44. model_config: ModelConfig,
  45. parallel_config: ParallelConfig,
  46. scheduler_config: SchedulerConfig,
  47. device_config: DeviceConfig,
  48. lora_config: Optional[LoRAConfig],
  49. kv_cache_dtype: Optional[str] = "auto",
  50. # kv_quant_params_path: Optional[str] = None,
  51. is_driver_worker: bool = False,
  52. ):
  53. self.model_config = model_config
  54. self.parallel_config = parallel_config
  55. self.scheduler_config = scheduler_config
  56. self.lora_config = lora_config
  57. self.is_driver_worker = is_driver_worker
  58. # model_config can be None in tests/samplers/test_sampler.py.
  59. # FIXME: This is a hack to make the tests work. Refactor this.
  60. self.sliding_window = (model_config.get_sliding_window()
  61. if model_config is not None else None)
  62. self.device_config = (device_config
  63. if device_config is not None else DeviceConfig())
  64. self.device = self.device_config.device
  65. self.model = None
  66. self.block_size = None # Set after initial profiling.
  67. self.lora_manager = None
  68. self.graph_runners: Dict[int, CUDAGraphRunner] = {}
  69. self.graph_memory_pool = None # Set during graph capture.
  70. self.max_context_len_to_capture = (
  71. self.model_config.max_context_len_to_capture
  72. if self.model_config is not None else 0)
  73. # When using CUDA graph, the input block tables must be padded to
  74. # max_context_len_to_capture. However, creating the block table in
  75. # Python can be expensive. To optimize this, we cache the block table
  76. # in numpy and only copy the actual input content at every iteration.
  77. # The shape of the cached block table will be
  78. # (max batch size to capture, max context len to capture / block size).
  79. self.graph_block_tables = None # Set after initial profiling.
  80. # cache in_wsl result
  81. self.in_wsl = in_wsl()
  82. self.kv_cache_dtype = kv_cache_dtype
  83. # self.kv_quant_params = (
  84. # self.load_kv_quant_params(model_config, kv_quant_params_path)
  85. # if self.kv_cache_dtype == "int8"
  86. # else None
  87. # )
  88. # Set enforce_eager to True for Neuron backend, to avoid capturing graph
  89. if self.device_config.is_neuron:
  90. self.model_config.enforce_eager = True
  91. # def load_kv_quant_params(
  92. # self, model_config: ModelConfig, kv_quant_params_path: str
  93. # ) -> List[List[float]]:
  94. # if model_config is None:
  95. # return None
  96. # # Remove it when all models support kv cache int8.
  97. # architectures = model_config.hf_config.architectures
  98. # for arch in architectures:
  99. # if arch not in ["LlamaForCausalLM", "LLaMAForCausalLM"]:
  100. # raise ValueError(
  101. # "KV CACHE INT8 is not supported for model architectures "
  102. # f"{arch} for now. "
  103. # "Supported architectures: LlamaForCausalLM and "
  104. # "LLaMAForCausalLM."
  105. # )
  106. # num_layers = model_config.hf_config.num_hidden_layers
  107. # kv_quant_params = []
  108. # for i in range(num_layers):
  109. # if kv_quant_params_path is not None:
  110. # path = (
  111. # kv_quant_params_path + f"/layers.{i}.past_kv_scale.0.weight" # noqa: E501
  112. # )
  113. # kv_quant_param = list(np.fromfile(path, dtype=np.float32))
  114. # kv_quant_params.append(kv_quant_param)
  115. # return kv_quant_params
  116. def load_model(self) -> None:
  117. with measure_cuda_memory() as m:
  118. self.model = get_model(
  119. self.model_config,
  120. self.device_config,
  121. lora_config=self.lora_config,
  122. parallel_config=self.parallel_config,
  123. scheduler_config=self.scheduler_config,
  124. )
  125. self.model_memory_usage = m.consumed_memory
  126. tp = get_tensor_model_parallel_world_size()
  127. logger.info(
  128. "Model weights loaded. Memory usage: "
  129. f"{self.model_memory_usage / float(2**30):.2f} GiB x {tp} = "
  130. f"{self.model_memory_usage * tp / float(2**30):.2f} GiB")
  131. if self.lora_config:
  132. assert (hasattr(self.model, "supported_lora_modules")
  133. and self.model.supported_lora_modules
  134. ), "Model does not support LoRA"
  135. assert hasattr(
  136. self.model,
  137. "embedding_modules"), "Model does not have embedding_modules"
  138. assert hasattr(self.model, "embedding_padding_modules"
  139. ), "Model does not have embedding_padding_modules"
  140. self.lora_manager = LRUCacheWorkerLoRAManager(
  141. self.scheduler_config.max_num_seqs,
  142. self.scheduler_config.max_num_batched_tokens +
  143. self.scheduler_config.max_paddings,
  144. self.vocab_size,
  145. self.lora_config,
  146. self.device,
  147. self.model.embedding_modules,
  148. self.model.embedding_padding_modules,
  149. )
  150. self.model = self.lora_manager.create_lora_manager(self.model)
  151. def set_block_size(self, block_size: int) -> None:
  152. self.block_size = block_size
  153. max_num_blocks = (self.max_context_len_to_capture + block_size -
  154. 1) // block_size
  155. self.graph_block_tables = np.zeros(
  156. (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
  157. def _prepare_prompt(
  158. self,
  159. seq_group_metadata_list: List[SequenceGroupMetadata],
  160. ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
  161. List[int], List[int], Set[LoRARequest], ]:
  162. assert len(seq_group_metadata_list) > 0
  163. input_tokens: List[List[int]] = []
  164. input_positions: List[List[int]] = []
  165. slot_mapping: List[List[int]] = []
  166. lora_index_mapping: List[int] = []
  167. lora_prompt_mapping: List[int] = []
  168. lora_requests: Set[LoRARequest] = set()
  169. prompt_lens: List[int] = []
  170. context_lens: List[int] = []
  171. subquery_lens: List[int] = []
  172. prefix_block_tables: List[List[int]] = []
  173. for seq_group_metadata in seq_group_metadata_list:
  174. assert seq_group_metadata.is_prompt
  175. seq_ids = list(seq_group_metadata.seq_data.keys())
  176. assert len(seq_ids) == 1
  177. seq_id = seq_ids[0]
  178. seq_data = seq_group_metadata.seq_data[seq_id]
  179. prompt_tokens = seq_data.get_token_ids()
  180. prompt_len = len(prompt_tokens)
  181. prompt_lens.append(prompt_len)
  182. computed_len = 0
  183. # NOTE: This only works for oooooooxxx style attention.
  184. computed_block_nums = seq_group_metadata.computed_block_nums
  185. if (computed_block_nums is not None
  186. and len(computed_block_nums) > 0
  187. and self.sliding_window is None):
  188. # Prefix is not supported with sliding_window
  189. computed_len = len(computed_block_nums) * self.block_size
  190. prompt_tokens = prompt_tokens[computed_len:]
  191. prefix_block_tables.append(computed_block_nums)
  192. else:
  193. prefix_block_tables.append([])
  194. # actual prompt lens
  195. context_lens.append(computed_len)
  196. subquery_lens.append(prompt_len - computed_len)
  197. input_tokens.append(prompt_tokens)
  198. # NOTE: Here we assume that the first token in the prompt
  199. # is always the first token in the sequence.
  200. input_positions.append(
  201. list(range(computed_len, computed_len + len(prompt_tokens))))
  202. lora_id = seq_group_metadata.lora_int_id
  203. if lora_id > 0:
  204. lora_requests.add(seq_group_metadata.lora_request)
  205. lora_index_mapping.append([lora_id] * (prompt_len - computed_len))
  206. lora_prompt_mapping.extend(
  207. [lora_id] *
  208. (prompt_len - computed_len
  209. if seq_group_metadata.sampling_params.prompt_logprobs else 1))
  210. if seq_group_metadata.block_tables is None:
  211. # During memory profiling, the block tables are not initialized
  212. # yet. In this case, we just use a dummy slot mapping.
  213. slot_mapping.append([_PAD_SLOT_ID] * prompt_len)
  214. continue
  215. # Compute the slot mapping.
  216. slot_mapping.append([])
  217. block_table = seq_group_metadata.block_tables[seq_id]
  218. # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
  219. # where start_idx is max(0, prompt_len - sliding_window).
  220. # For example, if the prompt len is 10, sliding window is 8, and
  221. # block size is 4, the first two tokens are masked and the slot
  222. # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
  223. start_idx = 0
  224. if self.sliding_window is not None:
  225. assert computed_len == 0, (
  226. "Prefix caching is currently not supported with "
  227. "sliding window attention")
  228. start_idx = max(0, prompt_len - self.sliding_window)
  229. for i in range(computed_len, prompt_len):
  230. if i < start_idx:
  231. slot_mapping[-1].append(_PAD_SLOT_ID)
  232. continue
  233. block_number = block_table[i // self.block_size]
  234. block_offset = i % self.block_size
  235. slot = block_number * self.block_size + block_offset
  236. slot_mapping[-1].append(slot)
  237. max_prompt_len = max(subquery_lens)
  238. assert max_prompt_len > 0
  239. input_tokens = _make_tensor_with_pad(
  240. input_tokens,
  241. max_prompt_len,
  242. pad=0,
  243. dtype=torch.long,
  244. device=self.device,
  245. )
  246. input_positions = _make_tensor_with_pad(
  247. input_positions,
  248. max_prompt_len,
  249. pad=0,
  250. dtype=torch.long,
  251. device=self.device,
  252. )
  253. slot_mapping = _make_tensor_with_pad(
  254. slot_mapping,
  255. max_prompt_len,
  256. pad=_PAD_SLOT_ID,
  257. dtype=torch.long,
  258. device=self.device,
  259. )
  260. lora_index_mapping = [
  261. _pad_to_max(mapping, max_prompt_len, pad=0)
  262. for mapping in lora_index_mapping
  263. ]
  264. context_lens_tensor = torch.tensor(context_lens,
  265. dtype=torch.int,
  266. device=self.device)
  267. # Prepare prefix block tables
  268. max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
  269. block_tables = _make_tensor_with_pad(
  270. prefix_block_tables,
  271. max_len=max_prompt_block_table_len,
  272. pad=0,
  273. dtype=torch.int,
  274. device=self.device,
  275. )
  276. start_loc_tensor = torch.arange(
  277. 0,
  278. len(prompt_lens) * max_prompt_len,
  279. max_prompt_len,
  280. dtype=torch.long,
  281. device=self.device,
  282. )
  283. prompt_lens_tensor = torch.tensor(prompt_lens,
  284. dtype=torch.long,
  285. device=self.device)
  286. input_metadata = InputMetadata(
  287. is_prompt=True,
  288. slot_mapping=slot_mapping,
  289. prompt_lens=prompt_lens_tensor,
  290. max_seq_len=max_prompt_len,
  291. start_loc=start_loc_tensor,
  292. max_context_len=None,
  293. context_lens=context_lens_tensor,
  294. block_tables=block_tables,
  295. use_cuda_graph=False,
  296. kv_cache_dtype=self.kv_cache_dtype,
  297. # kv_quant_params=self.kv_quant_params,
  298. )
  299. return (
  300. input_tokens,
  301. input_positions,
  302. input_metadata,
  303. prompt_lens,
  304. subquery_lens,
  305. lora_index_mapping,
  306. lora_prompt_mapping,
  307. lora_requests,
  308. )
  309. def _prepare_decode(
  310. self,
  311. seq_group_metadata_list: List[SequenceGroupMetadata],
  312. ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
  313. Set[LoRARequest], ]:
  314. assert len(seq_group_metadata_list) > 0
  315. input_tokens: List[List[int]] = []
  316. input_positions: List[List[int]] = []
  317. slot_mapping: List[List[int]] = []
  318. context_lens: List[int] = []
  319. block_tables: List[List[int]] = []
  320. lora_index_mapping: List[int] = []
  321. lora_prompt_mapping: List[int] = []
  322. lora_requests: Set[LoRARequest] = set()
  323. for seq_group_metadata in seq_group_metadata_list:
  324. assert not seq_group_metadata.is_prompt
  325. seq_ids = list(seq_group_metadata.seq_data.keys())
  326. lora_id = seq_group_metadata.lora_int_id
  327. if lora_id > 0:
  328. lora_requests.add(seq_group_metadata.lora_request)
  329. for seq_id in seq_ids:
  330. seq_data = seq_group_metadata.seq_data[seq_id]
  331. generation_token = seq_data.get_last_token_id()
  332. input_tokens.append([generation_token])
  333. seq_len = seq_data.get_len()
  334. position = seq_len - 1
  335. input_positions.append([position])
  336. context_len = (seq_len if self.sliding_window is None else min(
  337. seq_len, self.sliding_window))
  338. context_lens.append(context_len)
  339. block_table = seq_group_metadata.block_tables[seq_id]
  340. block_number = block_table[position // self.block_size]
  341. block_offset = position % self.block_size
  342. slot = block_number * self.block_size + block_offset
  343. slot_mapping.append([slot])
  344. lora_index_mapping.append([lora_id])
  345. lora_prompt_mapping.append(lora_id)
  346. if self.sliding_window is not None:
  347. sliding_window_blocks = (self.sliding_window //
  348. self.block_size)
  349. block_table = block_table[-sliding_window_blocks:]
  350. block_tables.append(block_table)
  351. batch_size = len(input_tokens)
  352. max_context_len = max(context_lens)
  353. use_captured_graph = (
  354. not self.model_config.enforce_eager
  355. and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
  356. and max_context_len <= self.max_context_len_to_capture)
  357. if use_captured_graph:
  358. # Pad the input tokens, positions, and slot mapping to match the
  359. # batch size of the captured graph.
  360. graph_batch_size = _get_graph_batch_size(batch_size)
  361. assert graph_batch_size >= batch_size
  362. for _ in range(graph_batch_size - batch_size):
  363. input_tokens.append([])
  364. input_positions.append([])
  365. slot_mapping.append([])
  366. context_lens.append(1)
  367. block_tables.append([])
  368. batch_size = graph_batch_size
  369. input_tokens = _make_tensor_with_pad(input_tokens,
  370. max_len=1,
  371. pad=0,
  372. dtype=torch.long,
  373. device=self.device)
  374. input_positions = _make_tensor_with_pad(
  375. input_positions,
  376. max_len=1,
  377. pad=0,
  378. dtype=torch.long,
  379. device=self.device,
  380. )
  381. slot_mapping = _make_tensor_with_pad(
  382. slot_mapping,
  383. max_len=1,
  384. pad=_PAD_SLOT_ID,
  385. dtype=torch.long,
  386. device=self.device,
  387. )
  388. context_lens = torch.tensor(context_lens,
  389. dtype=torch.int,
  390. device=self.device)
  391. if use_captured_graph:
  392. # The shape of graph_block_tables is
  393. # [max batch size, max context len // block size].
  394. input_block_tables = self.graph_block_tables[:batch_size]
  395. for i, block_table in enumerate(block_tables):
  396. if block_table:
  397. input_block_tables[i, :len(block_table)] = block_table
  398. block_tables = torch.tensor(input_block_tables, device=self.device)
  399. else:
  400. max_block_table_len = max(
  401. len(block_table) for block_table in block_tables)
  402. block_tables = _make_tensor_with_pad(
  403. block_tables,
  404. max_len=max_block_table_len,
  405. pad=0,
  406. dtype=torch.int,
  407. device=self.device,
  408. )
  409. lora_index_mapping = [
  410. _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping
  411. ]
  412. input_metadata = InputMetadata(
  413. is_prompt=False,
  414. slot_mapping=slot_mapping,
  415. prompt_lens=None,
  416. max_seq_len=None,
  417. start_loc=None,
  418. max_context_len=max_context_len,
  419. context_lens=context_lens,
  420. block_tables=block_tables,
  421. use_cuda_graph=use_captured_graph,
  422. kv_cache_dtype=self.kv_cache_dtype,
  423. # kv_quant_params=self.kv_quant_params,
  424. )
  425. return (
  426. input_tokens,
  427. input_positions,
  428. input_metadata,
  429. lora_index_mapping,
  430. lora_prompt_mapping,
  431. lora_requests,
  432. )
  433. def _prepare_sample(
  434. self,
  435. seq_group_metadata_list: List[SequenceGroupMetadata],
  436. prompt_lens: List[int],
  437. subquery_lens: Optional[List[int]],
  438. ) -> SamplingMetadata:
  439. seq_groups: List[Tuple[List[int], SamplingParams]] = []
  440. selected_token_indices: List[int] = []
  441. generators: List[torch.Generator] = []
  442. selected_token_start_idx = 0
  443. categorized_sample_indices = {t: [] for t in SamplingType}
  444. categorized_sample_indices_start_idx = 0
  445. pin_memory = not self.in_wsl and not self.device_config.is_neuron
  446. max_subquery_len = max(subquery_lens) if subquery_lens else 1
  447. for i, seq_group_metadata in enumerate(seq_group_metadata_list):
  448. seq_ids = list(seq_group_metadata.seq_data.keys())
  449. sampling_params = seq_group_metadata.sampling_params
  450. seq_groups.append((seq_ids, sampling_params))
  451. if seq_group_metadata.is_prompt:
  452. assert len(seq_ids) == 1
  453. assert subquery_lens is not None
  454. subquery_len = subquery_lens[i]
  455. if sampling_params.prompt_logprobs is not None:
  456. # NOTE: prompt token positions do not need sample, skip
  457. categorized_sample_indices_start_idx += subquery_len - 1
  458. categorized_sample_indices[
  459. sampling_params.sampling_type].append(
  460. categorized_sample_indices_start_idx)
  461. categorized_sample_indices_start_idx += 1
  462. if sampling_params.prompt_logprobs is not None:
  463. selected_token_indices.extend(
  464. range(
  465. selected_token_start_idx,
  466. selected_token_start_idx + subquery_len - 1,
  467. ))
  468. selected_token_indices.append(selected_token_start_idx +
  469. subquery_len - 1)
  470. selected_token_start_idx += max_subquery_len
  471. if sampling_params.seed is not None:
  472. seq_group_metadata.state.generator = torch.Generator(
  473. device="cuda").manual_seed(sampling_params.seed)
  474. else:
  475. num_seqs = len(seq_ids)
  476. selected_token_indices.extend(
  477. range(
  478. selected_token_start_idx,
  479. selected_token_start_idx + num_seqs,
  480. ))
  481. selected_token_start_idx += num_seqs
  482. categorized_sample_indices[
  483. sampling_params.sampling_type].extend(
  484. range(
  485. categorized_sample_indices_start_idx,
  486. categorized_sample_indices_start_idx + num_seqs,
  487. ))
  488. categorized_sample_indices_start_idx += num_seqs
  489. if sampling_params.seed is not None:
  490. generators.append(seq_group_metadata.state.generator)
  491. selected_token_indices = _async_h2d(
  492. selected_token_indices,
  493. dtype=torch.long,
  494. target_device=self.device,
  495. pin_memory=pin_memory,
  496. )
  497. categorized_sample_indices = {
  498. t: _async_h2d(
  499. seq_ids,
  500. dtype=torch.int,
  501. target_device=self.device,
  502. pin_memory=pin_memory,
  503. )
  504. for t, seq_ids in categorized_sample_indices.items()
  505. }
  506. seq_data: Dict[int, SequenceData] = {}
  507. for seq_group_metadata in seq_group_metadata_list:
  508. seq_data.update(seq_group_metadata.seq_data)
  509. seq_persistence_data: Dict[int, dict] = {}
  510. for grp in seq_group_metadata_list:
  511. seq_persistence_data.update(grp.persistent_data)
  512. sampling_metadata = SamplingMetadata(
  513. seq_groups=seq_groups,
  514. seq_data=seq_data,
  515. prompt_lens=prompt_lens,
  516. selected_token_indices=selected_token_indices,
  517. categorized_sample_indices=categorized_sample_indices,
  518. generators=generators,
  519. persistent_metadata=PersistentMetadata(seq_persistence_data),
  520. )
  521. return sampling_metadata
  522. def prepare_input_tensors(
  523. self,
  524. seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
  525. ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata,
  526. Set[int], LoRAMapping, ]:
  527. if self.is_driver_worker:
  528. # NOTE: We assume that all sequences in the group are all prompts or
  529. # all decodes.
  530. is_prompt = seq_group_metadata_list[0].is_prompt
  531. # Prepare input tensors.
  532. if is_prompt:
  533. (
  534. input_tokens,
  535. input_positions,
  536. input_metadata,
  537. prompt_lens,
  538. subquery_lens,
  539. lora_index_mapping,
  540. lora_prompt_mapping,
  541. lora_requests,
  542. ) = self._prepare_prompt(seq_group_metadata_list)
  543. else:
  544. (
  545. input_tokens,
  546. input_positions,
  547. input_metadata,
  548. lora_index_mapping,
  549. lora_prompt_mapping,
  550. lora_requests,
  551. ) = self._prepare_decode(seq_group_metadata_list)
  552. prompt_lens = []
  553. subquery_lens = None
  554. sampling_metadata = self._prepare_sample(seq_group_metadata_list,
  555. prompt_lens,
  556. subquery_lens)
  557. if self.lora_config:
  558. flat_lora_index_mapping = [
  559. item for sublist in lora_index_mapping for item in sublist
  560. ]
  561. lora_mapping = LoRAMapping(
  562. flat_lora_index_mapping,
  563. lora_prompt_mapping,
  564. )
  565. else:
  566. lora_mapping = None
  567. # Broadcast the metadata.
  568. metadata_dict = {
  569. "input_tokens": input_tokens,
  570. "input_positions": input_positions,
  571. "is_prompt": input_metadata.is_prompt,
  572. "slot_mapping": input_metadata.slot_mapping,
  573. "prompt_lens": input_metadata.prompt_lens,
  574. "max_seq_len": input_metadata.max_seq_len,
  575. "start_loc": input_metadata.start_loc,
  576. "max_context_len": input_metadata.max_context_len,
  577. "context_lens": input_metadata.context_lens,
  578. "block_tables": input_metadata.block_tables,
  579. "use_cuda_graph": input_metadata.use_cuda_graph,
  580. "kv_cache_dtype": input_metadata.kv_cache_dtype,
  581. # "kv_quant_params": input_metadata.kv_quant_params,
  582. "selected_token_indices":
  583. sampling_metadata.selected_token_indices,
  584. "lora_requests": lora_requests,
  585. "lora_mapping": lora_mapping,
  586. }
  587. broadcast_tensor_dict(metadata_dict, src=0)
  588. else:
  589. metadata_dict = broadcast_tensor_dict(src=0)
  590. input_tokens = metadata_dict["input_tokens"]
  591. input_positions = metadata_dict["input_positions"]
  592. lora_mapping = metadata_dict["lora_mapping"]
  593. lora_requests = metadata_dict["lora_requests"]
  594. input_metadata = InputMetadata(
  595. is_prompt=metadata_dict["is_prompt"],
  596. slot_mapping=metadata_dict["slot_mapping"],
  597. prompt_lens=metadata_dict["prompt_lens"],
  598. max_seq_len=metadata_dict["max_seq_len"],
  599. start_loc=metadata_dict["start_loc"],
  600. max_context_len=metadata_dict["max_context_len"],
  601. context_lens=metadata_dict["context_lens"],
  602. block_tables=metadata_dict["block_tables"],
  603. use_cuda_graph=metadata_dict["use_cuda_graph"],
  604. kv_cache_dtype=metadata_dict["kv_cache_dtype"],
  605. # kv_quant_params=metadata_dict["kv_quant_params"],
  606. )
  607. sampling_metadata = SamplingMetadata(
  608. seq_groups=None,
  609. seq_data=None,
  610. prompt_lens=None,
  611. selected_token_indices=metadata_dict["selected_token_indices"],
  612. categorized_sample_indices=None,
  613. generators=None,
  614. perform_sampling=False,
  615. )
  616. return (
  617. input_tokens,
  618. input_positions,
  619. input_metadata,
  620. sampling_metadata,
  621. lora_requests,
  622. lora_mapping,
  623. )
  624. @torch.inference_mode()
  625. def execute_model(
  626. self,
  627. seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
  628. kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
  629. ) -> Optional[SamplerOutput]:
  630. (
  631. input_tokens,
  632. input_positions,
  633. input_metadata,
  634. sampling_metadata,
  635. lora_requests,
  636. lora_mapping,
  637. ) = self.prepare_input_tensors(seq_group_metadata_list)
  638. if self.lora_config:
  639. self.set_active_loras(lora_requests, lora_mapping)
  640. # Execute the model.
  641. if input_metadata.use_cuda_graph:
  642. graph_batch_size = input_tokens.shape[0]
  643. model_executable = self.graph_runners[graph_batch_size]
  644. else:
  645. model_executable = self.model
  646. hidden_states = model_executable(
  647. input_ids=input_tokens,
  648. positions=input_positions,
  649. kv_caches=kv_caches,
  650. input_metadata=input_metadata,
  651. )
  652. # Sample the next token.
  653. output = self.model.sample(
  654. hidden_states=hidden_states,
  655. sampling_metadata=sampling_metadata,
  656. )
  657. return output
  658. @torch.inference_mode()
  659. def profile_run(self) -> None:
  660. # Enable top-k sampling to reflect the accurate memory usage.
  661. sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
  662. max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
  663. max_num_seqs = self.scheduler_config.max_num_seqs
  664. # This represents the maximum number of different requests
  665. # that will have unique loras, an therefore the max amount of memory
  666. # consumption create dummy lora request copies from the lora request
  667. # passed in, which contains a lora from the lora warmup path.
  668. dummy_lora_requests = []
  669. dummy_lora_requests_per_seq = []
  670. if self.lora_config:
  671. for idx in range(self.lora_config.max_loras):
  672. lora_id = idx + 1
  673. dummy_lora_request = LoRARequest(
  674. lora_name=f"warmup_{lora_id}",
  675. lora_int_id=lora_id,
  676. lora_local_path="/not/a/real/path",
  677. )
  678. self.lora_manager.add_dummy_lora(dummy_lora_request,
  679. rank=LORA_WARMUP_RANK)
  680. dummy_lora_requests.append(dummy_lora_request)
  681. dummy_lora_requests_per_seq = [
  682. dummy_lora_requests[idx % len(dummy_lora_requests)]
  683. for idx in range(max_num_seqs)
  684. ]
  685. # Profile memory usage with max_num_sequences sequences and the total
  686. # number of tokens equal to max_num_batched_tokens.
  687. seqs: List[SequenceGroupMetadata] = []
  688. for group_id in range(max_num_seqs):
  689. seq_len = max_num_batched_tokens // max_num_seqs + (
  690. group_id < max_num_batched_tokens % max_num_seqs)
  691. seq_data = SequenceData([0] * seq_len)
  692. seq = SequenceGroupMetadata(
  693. request_id=str(group_id),
  694. is_prompt=True,
  695. seq_data={group_id: seq_data},
  696. sampling_params=sampling_params,
  697. block_tables=None,
  698. persistent_data={},
  699. lora_request=dummy_lora_requests_per_seq[group_id]
  700. if dummy_lora_requests_per_seq else None,
  701. )
  702. seqs.append(seq)
  703. # Run the model with the dummy inputs.
  704. num_layers = self.model_config.get_num_layers(self.parallel_config)
  705. kv_caches = [(None, None)] * num_layers
  706. self.execute_model(seqs, kv_caches)
  707. torch.cuda.synchronize()
  708. return
  709. def remove_all_loras(self) -> bool:
  710. if not self.lora_manager:
  711. raise RuntimeError("LoRA is not enabled.")
  712. return self.lora_manager.remove_all_loras()
  713. def set_active_loras(self, lora_requests: List[LoRARequest],
  714. lora_mapping: LoRAMapping) -> None:
  715. if not self.lora_manager:
  716. raise RuntimeError("LoRA is not enabled.")
  717. self.lora_manager.set_active_loras(lora_requests, lora_mapping)
  718. def add_lora(self, lora_request: LoRARequest) -> bool:
  719. if not self.lora_manager:
  720. raise RuntimeError("LoRA is not enabled.")
  721. return self.lora_manager.add_lora(lora_request)
  722. def remove_lora(self, lora_id: int) -> bool:
  723. if not self.lora_manager:
  724. raise RuntimeError("LoRA is not enabled.")
  725. return self.lora_manager.remove_lora(lora_id)
  726. def list_loras(self) -> Set[int]:
  727. if not self.lora_manager:
  728. raise RuntimeError("LoRA is not enabled.")
  729. return self.lora_manager.list_loras()
  @torch.inference_mode()
  def capture_model(self, kv_caches: List[KVCache]) -> None:
      """Capture one decode CUDA graph per batch size into ``graph_runners``.

      Every size in ``_BATCH_SIZES_TO_CAPTURE`` that is ``<=`` the padded
      graph batch size for ``scheduler_config.max_num_seqs`` is captured,
      iterating the list in reverse. The results are stored in
      ``self.graph_runners[batch_size]``, and all graphs share
      ``self.graph_memory_pool``.

      Args:
          kv_caches: KV cache tensors handed to every captured graph. The
              graphs keep them by reference; ``CUDAGraphRunner.forward``
              never copies them.
      """
      # NOTE: This is a hack to ensure that the NCCL backend is never
      # deleted before the CUDA graphs.
      self.cupy_nccl_backend = cupy_utils.get_nccl_backend()

      assert not self.model_config.enforce_eager
      logger.info("Capturing the model for CUDA graphs. This may lead to "
                  "unexpected consequences if the model is not static. To "
                  "run the model in eager mode, set 'enforce_eager=True' or "
                  "use '--enforce-eager' in the CLI.")
      logger.warning("CUDA graphs can take additional 1~3 GiB of memory "
                     "per GPU. If you are running out of memory, consider "
                     "decreasing `gpu_memory_utilization` or enforcing "
                     "eager mode.")
      start_time = time.perf_counter()

      # Prepare dummy inputs. These will be reused for all batch sizes.
      max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
      input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
      input_positions = torch.zeros(max_batch_size, 1,
                                    dtype=torch.long).cuda()
      slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda()
      slot_mapping.fill_(_PAD_SLOT_ID)
      context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
      block_tables = torch.from_numpy(self.graph_block_tables).cuda()

      # The upper bound is the *padded* size for max_num_seqs, which can
      # exceed max_num_seqs itself (e.g. 10 -> 16). Sizes in that gap must
      # be captured too. A previous `batch_size > max_num_seqs: continue`
      # skipped them, leaving no graph for a padded batch and leaving the
      # progress bar short of its total.
      # NOTE(review): assumes the decode path pads batches with
      # _get_graph_batch_size before looking up graph_runners — confirm.
      graph_batch_size = _get_graph_batch_size(
          self.scheduler_config.max_num_seqs)
      batch_size_capture_list = [
          bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
      ]

      # NOTE: There are 3 backends for all-reduce: custom all-reduce
      # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
      # either custom all-reduce kernel or CuPy NCCL. When not using CUDA
      # graph, we use either custom all-reduce kernel or PyTorch NCCL.
      # We always prioritize using custom all-reduce kernel but fall back
      # to PyTorch or CuPy NCCL if it is disabled or not supported.
      # Initialize a new progress bar
      progress = get_loading_progress_bar()
      task = progress.add_task("[cyan]Capturing graph...",
                               total=len(batch_size_capture_list))

      with progress, custom_all_reduce.capture():
          for batch_size in reversed(batch_size_capture_list):
              # Create dummy input_metadata.
              input_metadata = InputMetadata(
                  is_prompt=False,
                  slot_mapping=slot_mapping[:batch_size],
                  prompt_lens=None,
                  max_seq_len=None,
                  start_loc=None,
                  max_context_len=self.max_context_len_to_capture,
                  context_lens=context_lens[:batch_size],
                  block_tables=block_tables[:batch_size],
                  use_cuda_graph=True,
                  kv_cache_dtype=self.kv_cache_dtype,
                  # kv_quant_params=self.kv_quant_params,
              )

              if self.lora_config:
                  # Capture with an all-zero mapping and no active requests.
                  lora_mapping = LoRAMapping(
                      [0] * batch_size,
                      [0] * batch_size,
                  )
                  self.set_active_loras(set(), lora_mapping)

              graph_runner = CUDAGraphRunner(self.model)
              graph_runner.capture(
                  input_tokens[:batch_size],
                  input_positions[:batch_size],
                  kv_caches,
                  input_metadata,
                  memory_pool=self.graph_memory_pool,
              )
              # Hand this graph's pool to the next capture so they share it.
              self.graph_memory_pool = graph_runner.graph.pool()
              self.graph_runners[batch_size] = graph_runner

              # Update the progress bar
              progress.update(task, advance=1)

      elapsed_time = time.perf_counter() - start_time
      # This usually takes < 10 seconds.
      logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
  809. def __del__(self) -> None:
  810. # Delete the CUDA graphs before deleting the CuPy NCCL communicator.
  811. # NOTE: This is necessary because otherwise deadlocks can
  812. # happen.
  813. # FIXME: This is a bit hacky. Find a more robust solution.
  814. self.graph_runners.clear()
  815. self.cupy_nccl_backend = None
  @property
  def vocab_size(self) -> int:
      """Vocabulary size as reported by ``self.model_config``."""
      config = self.model_config
      return config.get_vocab_size()
  class CUDAGraphRunner:
      """Record one forward pass of ``model`` as a CUDA graph and replay it.

      ``capture`` runs the model once eagerly, then records a second run with
      ``torch.cuda.graph`` and keeps the tensors it was given as the graph's
      static input buffers. ``forward`` copies new inputs into those buffers
      and replays the recorded graph.
      """

      def __init__(self, model: nn.Module):
          self.model = model
          self.graph = None  # assigned by capture()
          self.input_buffers: Dict[str, torch.Tensor] = {}
          self.output_buffers: Dict[str, torch.Tensor] = {}

      def capture(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[KVCache],
          input_metadata: InputMetadata,
          memory_pool,
      ) -> None:
          """Warm the model up, then record it into ``self.graph``.

          Args:
              input_ids: Token ids for the recorded run; kept as a static buffer.
              positions: Positions for the recorded run; kept as a static buffer.
              kv_caches: KV caches used by the recorded run; kept by reference.
              input_metadata: Its ``slot_mapping``, ``context_lens`` and
                  ``block_tables`` tensors are kept as static buffers.
              memory_pool: Passed as ``pool=`` to ``torch.cuda.graph``.
          """
          assert self.graph is None
          model_inputs = (input_ids, positions, kv_caches, input_metadata)

          # Eager warm-up so one-time kernel launches (e.g. Triton autotune)
          # are not part of the recorded graph.
          with _maybe_cupy_nccl():
              self.model(*model_inputs)
          torch.cuda.synchronize()

          # Record the graph.
          # NOTE: Python 3.8 does not support multi-line with statements.
          # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
          self.graph = torch.cuda.CUDAGraph()
          with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
              with _maybe_cupy_nccl():
                  hidden_states = self.model(*model_inputs)
          torch.cuda.synchronize()

          # forward() copies new values into these exact tensors before replay.
          self.input_buffers = dict(
              input_ids=input_ids,
              positions=positions,
              kv_caches=kv_caches,
              slot_mapping=input_metadata.slot_mapping,
              context_lens=input_metadata.context_lens,
              block_tables=input_metadata.block_tables,
          )
          self.output_buffers = dict(hidden_states=hidden_states)

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          """Copy inputs into the static buffers and replay the graph.

          Returns the ``hidden_states`` tensor recorded by ``capture``. It is
          the same tensor object on every call, and each replay rewrites its
          contents.
          """
          # KV caches are fixed tensors, so we don't need to copy them.
          del kv_caches
          buffers = self.input_buffers
          buffers["input_ids"].copy_(input_ids, non_blocking=True)
          buffers["positions"].copy_(positions, non_blocking=True)
          for name in ("slot_mapping", "context_lens", "block_tables"):
              buffers[name].copy_(getattr(input_metadata, name),
                                  non_blocking=True)
          self.graph.replay()
          return self.output_buffers["hidden_states"]

      def __call__(self, *args, **kwargs):
          return self.forward(*args, **kwargs)
  @contextlib.contextmanager
  def _maybe_cupy_nccl():
      """Enter ``with_cupy_nccl_for_all_reduce()`` only when it applies.

      That context is used when CuPy NCCL is initialized and the custom
      all-reduce kernel is not. In every other case the body runs with no
      extra context.
      """
      use_cupy_nccl = (cupy_utils.is_initialized()
                       and not custom_all_reduce.is_initialized())
      ctx = (with_cupy_nccl_for_all_reduce()
             if use_cupy_nccl else contextlib.nullcontext())
      with ctx:
          yield
  900. def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
  901. assert len(x) <= max_len
  902. return x + [pad] * (max_len - len(x))
  903. def _make_tensor_with_pad(
  904. x: List[List[int]],
  905. max_len: int,
  906. pad: int,
  907. dtype: torch.dtype,
  908. device: Optional[Union[str, torch.device]],
  909. ) -> torch.Tensor:
  910. padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
  911. return torch.tensor(padded_x, dtype=dtype, device=device)
  912. def _get_graph_batch_size(batch_size: int) -> int:
  913. if batch_size <= 2:
  914. return batch_size
  915. elif batch_size <= 4:
  916. return 4
  917. else:
  918. return (batch_size + 7) // 8 * 8
  919. def _async_h2d(
  920. data: list,
  921. dtype: torch.dtype,
  922. target_device: Union[str, torch.device],
  923. pin_memory: bool,
  924. ) -> torch.Tensor:
  925. t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
  926. return t.to(device=target_device, non_blocking=True)