model_runner.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153
  1. import time
  2. from enum import IntEnum
  3. from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
  4. import numpy as np
  5. import torch
  6. import torch.nn as nn
  7. from loguru import logger
  8. from aphrodite.attention import (AttentionMetadata, AttentionMetadataPerStage,
  9. get_attn_backend)
  10. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  11. LoRAConfig, ModelConfig, ParallelConfig,
  12. SchedulerConfig, VisionLanguageConfig)
  13. from aphrodite.common.sampling_params import SamplingParams
  14. from aphrodite.common.sequence import (MultiModalData, SamplerOutput,
  15. SequenceData, SequenceGroupMetadata)
  16. from aphrodite.common.utils import (CudaMemoryProfiler,
  17. get_kv_cache_torch_dtype, is_hip,
  18. is_pin_memory_available,
  19. make_tensor_with_pad)
  20. from aphrodite.distributed import broadcast_tensor_dict
  21. from aphrodite.distributed.communication_op import graph_capture, graph_mode
  22. from aphrodite.distributed.parallel_state import \
  23. get_tensor_model_parallel_world_size
  24. from aphrodite.lora.layers import LoRAMapping
  25. from aphrodite.lora.request import LoRARequest
  26. from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
  27. from aphrodite.modeling import SamplingMetadata
  28. from aphrodite.modeling.model_loader import get_model
  29. _PAD_SLOT_ID = -1
  30. LORA_WARMUP_RANK = 8
  31. _BATCH_SIZE_ALIGNMENT = 8
  32. # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
  33. # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
  34. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
  35. _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
  36. ]
  class PreparePromptMetadata(NamedTuple):
      """Values produced by ``ModelRunner._prepare_prompt`` for prefill requests.

      The field order is part of the interface: callers unpack the tuple
      positionally.
      """
      # Token ids fed to the model, flattened over all prefill sequences.
      input_tokens: List[int]
      # Position of every entry in ``input_tokens`` within its sequence.
      input_positions: List[int]
      # Prefill attention metadata; ``None`` when there is no prefill request.
      attn_metadata: Optional[AttentionMetadataPerStage]
      # Per-sequence length (context plus the tokens processed this step).
      seq_lens: List[int]
      # Per-sequence number of tokens actually processed this step.
      query_lens: List[int]
      # LoRA id for each input token.
      lora_index_mapping: List[int]
      # LoRA id repeated per token when prompt logprobs are requested,
      # otherwise once per sequence.
      lora_prompt_mapping: List[int]
      # LoRA requests referenced by the batch.
      lora_requests: Set[LoRARequest]
      # Concatenated multi-modal input, or ``None`` if the batch has none.
      multi_modal_input: Optional[torch.Tensor]
      # KV-cache slot for every input token (``_PAD_SLOT_ID`` when unmapped).
      slot_mapping: List[int]

      @classmethod
      def empty(cls):
          """Return the metadata for a batch that has no prefill requests."""
          return PreparePromptMetadata(
              attn_metadata=None,
              multi_modal_input=None,
              input_tokens=[],
              input_positions=[],
              slot_mapping=[],
              seq_lens=[],
              query_lens=[],
              lora_index_mapping=[],
              lora_prompt_mapping=[],
              lora_requests=set(),
          )
  class PrepareDecodeMetadata(NamedTuple):
      """Values produced by ``ModelRunner._prepare_decode`` for decode requests.

      The field order is part of the interface: callers unpack the tuple
      positionally.
      """
      # Last generated token of every decoding sequence (plus graph padding).
      input_tokens: List[int]
      # Position of each input token within its sequence.
      input_positions: List[int]
      # Decode attention metadata; ``None`` when there is no decode request.
      attn_metadata: Optional[AttentionMetadata]
      # LoRA id for each input token.
      lora_index_mapping: List[int]
      # LoRA id for each decoding sequence.
      lora_prompt_mapping: List[int]
      # LoRA requests referenced by the batch.
      lora_requests: Set[LoRARequest]
      # KV-cache slot that each new token is written to.
      slot_mapping: List[int]

      @classmethod
      def empty(cls):
          """Return the metadata for a batch that has no decode requests."""
          return PrepareDecodeMetadata(
              attn_metadata=None,
              input_tokens=[],
              input_positions=[],
              slot_mapping=[],
              lora_index_mapping=[],
              lora_prompt_mapping=[],
              lora_requests=set(),
          )
  class BatchType(IntEnum):
      """How a batch is constructed.

      The value is broadcast from the driver worker so that the other workers
      know which attention metadata to rebuild.
      """
      PREFILL = 0  # Every request in the batch is a prefill.
      DECODE = 1  # Every request in the batch is a decode.
      MIXED = 2  # The batch contains both prefill and decode requests.
  89. class ModelRunner:
  def __init__(
      self,
      model_config: ModelConfig,
      parallel_config: ParallelConfig,
      scheduler_config: SchedulerConfig,
      device_config: DeviceConfig,
      cache_config: CacheConfig,
      load_config: LoadConfig,
      lora_config: Optional[LoRAConfig],
      kv_cache_dtype: Optional[str] = "auto",
      is_driver_worker: bool = False,
      vision_language_config: Optional[VisionLanguageConfig] = None,
  ):
      """Record the configs and allocate the state that exists before loading.

      Args:
          model_config: Model settings; also supplies the sliding window,
              ``max_seq_len_to_capture`` and the dtype used to pick the
              attention backend.
          parallel_config: Parallelism settings.
          scheduler_config: Scheduler settings.
          device_config: Supplies the target device.
          cache_config: Supplies the KV-cache block size.
          load_config: Settings forwarded to the model loader.
          lora_config: LoRA settings, or ``None`` to disable LoRA.
          kv_cache_dtype: Name of the KV-cache data type.
          is_driver_worker: Whether this runner prepares and broadcasts inputs.
          vision_language_config: Settings for vision-language models, if any.
      """
      # Keep a reference to every config object.
      self.model_config = model_config
      self.parallel_config = parallel_config
      self.scheduler_config = scheduler_config
      self.device_config = device_config
      self.cache_config = cache_config
      self.lora_config = lora_config
      self.load_config = load_config
      self.vision_language_config = vision_language_config
      self.is_driver_worker = is_driver_worker

      # Values derived from the configs.
      self.device = device_config.device
      self.pin_memory = is_pin_memory_available()
      self.kv_cache_dtype = kv_cache_dtype
      self.sliding_window = model_config.get_sliding_window()
      self.block_size = cache_config.block_size
      self.max_seq_len_to_capture = model_config.max_seq_len_to_capture

      # CUDA graph state.
      self.graph_runners: Dict[int, CUDAGraphRunner] = {}
      # Set during graph capture.
      self.graph_memory_pool: Optional[Tuple[int, int]] = None
      # When using CUDA graph, the input block tables must be padded to
      # max_seq_len_to_capture. Building that table in Python at every step is
      # expensive, so a numpy table is cached here and only the actual input
      # content is copied into it at every iteration. Its shape is
      # (max batch size to capture, max context len to capture / block size).
      largest_batch = max(_BATCH_SIZES_TO_CAPTURE)
      blocks_per_seq = self.get_max_block_per_batch()
      self.graph_block_tables = np.zeros((largest_batch, blocks_per_seq),
                                         dtype=np.int32)

      self.attn_backend = get_attn_backend(model_config.dtype)

      # Lazily initialised attributes (declared only, not assigned yet).
      self.model: torch.nn.Module  # Set after load_model
      # Set if the backend is flashinfer.
      self.flashinfer_workspace_buffer: torch.Tensor
      # Set after load_model.
      self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
  def load_model(self) -> None:
      """Load the model, then set up LoRA and KV-cache scaling factors.

      The model is loaded inside a ``CudaMemoryProfiler`` context and the
      measured usage is stored in ``self.model_memory_usage`` and logged.
      When a LoRA config is present the model is wrapped by a
      ``LRUCacheWorkerLoRAManager``. Finally, KV-cache scaling factors are
      loaded when the cache dtype is ``"fp8"`` on ROCm and a scaling-factor
      path is configured; otherwise a warning explains why they are unused.

      Raises:
          AssertionError: If LoRA is configured but the model lacks
              ``supported_lora_modules``, ``embedding_modules`` or
              ``embedding_padding_modules``.
          RuntimeError: If FP8 scaling factors were provided but the model
              has no callable ``load_kv_cache_scales``.
      """
      with CudaMemoryProfiler() as m:
          self.model = get_model(
              model_config=self.model_config,
              device_config=self.device_config,
              load_config=self.load_config,
              lora_config=self.lora_config,
              vision_language_config=self.vision_language_config,
              parallel_config=self.parallel_config,
              scheduler_config=self.scheduler_config,
          )

      self.model_memory_usage = m.consumed_memory
      tp = get_tensor_model_parallel_world_size()
      logger.info(
          "Model weights loaded. Memory usage: "
          f"{self.model_memory_usage / float(2**30):.2f} GiB x {tp} = "
          f"{self.model_memory_usage * tp / float(2**30):.2f} GiB")

      if self.lora_config:
          assert hasattr(self.model, "supported_lora_modules"
                         ) and self.model.supported_lora_modules, (
                             "Model does not support LoRA")
          assert hasattr(
              self.model,
              "embedding_modules"), "Model does not have embedding_modules"
          assert hasattr(self.model, "embedding_padding_modules"
                         ), "Model does not have embedding_padding_modules"
          self.lora_manager = LRUCacheWorkerLoRAManager(
              self.scheduler_config.max_num_seqs,
              self.scheduler_config.max_num_batched_tokens, self.vocab_size,
              self.lora_config, self.device, self.model.embedding_modules,
              self.model.embedding_padding_modules)
          # From here on self.model is whatever the LoRA manager returns.
          self.model = self.lora_manager.create_lora_manager(self.model)

      if self.kv_cache_dtype == "fp8" and is_hip():
          # Currently scaled KV cache is only enabled on ROCm
          if self.model_config.quantization_param_path is not None:
              if callable(getattr(self.model, "load_kv_cache_scales", None)):
                  self.model.load_kv_cache_scales(
                      self.model_config.quantization_param_path)
              else:
                  # Exceptions do not interpolate %-style arguments the way
                  # logging calls do, so the message is formatted here.
                  raise RuntimeError(
                      "Using FP8 KV cache and scaling factors provided but "
                      f"model {self.model.__class__} does not support "
                      "loading scaling factors.")
          else:
              logger.warning(
                  "Using FP8 KV cache but no scaling factors "
                  "provided. Defaulting to scaling factors of 1.0. "
                  "This may lead to less accurate results!")
      elif self.model_config.quantization_param_path is not None:
          logger.warning("KV cache scaling factors provided, "
                         "but the KV cache data type is not FP8. "
                         "KV cache scaling factors will not be used.")
  189. def get_max_block_per_batch(self) -> int:
  190. block_size = self.block_size
  191. return (self.max_seq_len_to_capture + block_size - 1) // block_size
  def _prepare_prompt(
      self,
      seq_group_metadata_list: List[SequenceGroupMetadata],
  ) -> PreparePromptMetadata:
      """Flatten the prefill requests of a batch into model inputs.

      For every request the tokens processed this step, their positions,
      LoRA mappings and KV-cache slots are appended to flat lists, and the
      prefill attention metadata is built from the per-sequence lengths.

      Args:
          seq_group_metadata_list: Prefill sequence groups; each must hold
              exactly one sequence.

      Returns:
          The collected values, or ``PreparePromptMetadata.empty()`` when the
          list is empty.

      Raises:
          RuntimeError: If chunked prefill is enabled and a request carries
              computed (prefix-cached) blocks.
      """
      if len(seq_group_metadata_list) == 0:
          return PreparePromptMetadata.empty()

      input_tokens: List[int] = []
      input_positions: List[int] = []
      slot_mapping: List[int] = []
      lora_index_mapping: List[int] = []
      lora_prompt_mapping: List[int] = []
      lora_requests: Set[LoRARequest] = set()
      seq_lens: List[int] = []
      context_lens: List[int] = []
      query_lens: List[int] = []
      prefix_block_tables: List[List[int]] = []
      multi_modal_input_list: List[torch.Tensor] = []

      for group in seq_group_metadata_list:
          assert group.is_prompt
          seq_ids = list(group.seq_data.keys())
          assert len(seq_ids) == 1
          seq_id = seq_ids[0]

          computed_block_nums = group.computed_block_nums
          if (self.scheduler_config is not None
                  and self.scheduler_config.chunked_prefill_enabled
                  and computed_block_nums is not None
                  and not computed_block_nums == []):
              raise RuntimeError(
                  "chunked prefill cannot be used with prefix caching "
                  "now.")

          token_chunk_size = group.token_chunk_size
          seq_data = group.seq_data[seq_id]
          context_len = seq_data.get_num_computed_tokens()
          # We should use get_len here because in case of preemption
          # it contains output tokens.
          seq_len = min(seq_data.get_len(), context_len + token_chunk_size)
          prompt_tokens = seq_data.get_token_ids()[context_len:seq_len]
          seq_lens.append(seq_len)

          # NOTE: This only works for oooooooxxx style attention.
          if computed_block_nums is not None and len(
                  computed_block_nums) > 0 and self.sliding_window is None:
              # Prefix is not supported with sliding_window
              context_len = len(computed_block_nums) * self.block_size
              prompt_tokens = prompt_tokens[context_len:]
              prefix_block_tables.append(computed_block_nums)
          elif self.scheduler_config.chunked_prefill_enabled:
              if group.block_tables is None:
                  # The first prefill.
                  prefix_block_tables.append([])
              else:
                  # Prefill has chunked before.
                  prefix_block_tables.append(group.block_tables[seq_id])
          else:
              prefix_block_tables.append([])
              # Right now, prefill start is always 0. However, this
              # assumption can be changed once chunked prefill is introduced.
              assert context_len == 0

          # Number of tokens actually fed to the model for this sequence.
          query_len = seq_len - context_len
          context_lens.append(context_len)
          query_lens.append(query_len)

          input_tokens.extend(prompt_tokens)
          # NOTE: Here we assume that the first token in the prompt
          # is always the first token in the sequence.
          input_positions.extend(list(range(context_len, seq_len)))

          lora_id = group.lora_int_id
          if lora_id > 0:
              lora_requests.add(group.lora_request)

          lora_index_mapping += [lora_id] * query_len
          # One entry per token when prompt logprobs are requested,
          # otherwise a single entry for the sequence.
          if (group.sampling_params
                  and group.sampling_params.prompt_logprobs):
              prompt_mapping_len = query_len
          else:
              prompt_mapping_len = 1
          lora_prompt_mapping.extend([lora_id] * prompt_mapping_len)

          if group.multi_modal_data:
              multi_modal_input_list.append(group.multi_modal_data.data)

          if _is_block_tables_empty(group.block_tables):
              # During memory profiling, the block tables are not initialized
              # yet. In this case, we just use a dummy slot mapping.
              # In embeddings, the block tables are {seq_id: None}
              slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
              continue

          # Compute the slot mapping.
          block_table = group.block_tables[seq_id]
          # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
          # where start_idx is max(0, seq_len - sliding_window).
          # For example, if the prompt len is 10, sliding window is 8, and
          # block size is 4, the first two tokens are masked and the slot
          # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
          start_idx = 0
          if self.sliding_window is not None:
              assert context_len == 0, (
                  "Prefix caching is currently not supported with "
                  "sliding window attention")
              start_idx = max(0, seq_len - self.sliding_window)

          for token_idx in range(context_len, seq_len):
              if token_idx >= start_idx:
                  block_idx, block_offset = divmod(token_idx,
                                                   self.block_size)
                  slot_mapping.append(block_table[block_idx] *
                                      self.block_size + block_offset)
              else:
                  slot_mapping.append(_PAD_SLOT_ID)

      max_query_len = max(query_lens)
      max_seq_len = max(seq_lens)
      assert max_query_len > 0

      context_lens_tensor = torch.tensor(context_lens,
                                         dtype=torch.int,
                                         device=self.device)

      multi_modal_input = None
      if multi_modal_input_list:
          assert self.vision_language_config, (
              "Multi-modal inputs are only supported by "
              "vision language models.")
          multi_modal_input = torch.cat(multi_modal_input_list,
                                        dim=0).to(self.device)

      # Prepare prefix block tables, padded to the longest one.
      block_tables = make_tensor_with_pad(
          prefix_block_tables,
          max_len=max(len(table) for table in prefix_block_tables),
          pad=0,
          dtype=torch.int,
          device=self.device,
      )

      # Query length can be shorter than key (i.e., prompt) when prefill
      # is chunked or prefix cached.
      query_lens_tensor = torch.tensor(query_lens,
                                       dtype=torch.long,
                                       device=self.device)
      subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                       dtype=torch.int32,
                                       device=self.device)
      seq_lens_tensor = torch.tensor(seq_lens,
                                     dtype=torch.int,
                                     device=self.device)
      seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                  dtype=torch.int32,
                                  device=self.device)

      # Entry 0 stays zero; the running totals fill entries 1..n.
      torch.cumsum(query_lens_tensor,
                   dim=0,
                   dtype=subquery_start_loc.dtype,
                   out=subquery_start_loc[1:])
      torch.cumsum(seq_lens_tensor,
                   dim=0,
                   dtype=seq_start_loc.dtype,
                   out=seq_start_loc[1:])

      if self.attn_backend.get_name() == "flashinfer":
          attn_metadata = self.attn_backend.make_metadata(
              is_prompt=True,
              use_cuda_graph=False,
              seq_start_loc=seq_start_loc,
              max_seq_len=max_seq_len,
              block_tables=block_tables)
      else:
          attn_metadata = self.attn_backend.make_metadata(
              is_prompt=True,
              seq_lens=seq_lens,
              seq_lens_tensor=seq_lens_tensor,
              max_query_len=max_query_len,
              max_seq_len=max_seq_len,
              subquery_start_loc=subquery_start_loc,
              seq_start_loc=seq_start_loc,
              context_lens_tensor=context_lens_tensor,
              block_tables=block_tables,
              use_cuda_graph=False,
          )

      return PreparePromptMetadata(
          input_tokens=input_tokens,
          input_positions=input_positions,
          attn_metadata=attn_metadata,
          seq_lens=seq_lens,
          query_lens=query_lens,
          lora_index_mapping=lora_index_mapping,
          lora_prompt_mapping=lora_prompt_mapping,
          lora_requests=lora_requests,
          multi_modal_input=multi_modal_input,
          slot_mapping=slot_mapping,
      )
  def _prepare_decode(
      self,
      seq_group_metadata_list: List[SequenceGroupMetadata],
  ) -> PrepareDecodeMetadata:
      """Flatten the decode requests of a batch into model inputs.

      Each sequence contributes its last token, that token's position and
      KV-cache slot, and its block table. When a captured CUDA graph can be
      used, the per-token lists are padded up to the graph batch size.

      Args:
          seq_group_metadata_list: Decode sequence groups; each must have a
              token chunk size of 1.

      Returns:
          The collected values, or ``PrepareDecodeMetadata.empty()`` when the
          list is empty.
      """
      if len(seq_group_metadata_list) == 0:
          return PrepareDecodeMetadata.empty()

      input_tokens: List[int] = []
      input_positions: List[int] = []
      slot_mapping: List[int] = []
      seq_lens: List[int] = []
      block_tables: List[List[int]] = []
      lora_index_mapping: List[int] = []
      lora_prompt_mapping: List[int] = []
      lora_requests: Set[LoRARequest] = set()

      # The following fields are only for flashinfer
      # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
      # for the precise definition of the following fields.
      # An example:
      # request 1, page indices [0, 5, 8]
      # request 2, page indices [1, 6, 7]
      # request 3, page indices [3, 4]
      # paged_kv_indices is a concatenation of page indices of all requests:
      # [0, 5, 8, 1, 6, 7, 3, 4]
      # paged_kv_indptr is used to index into paged_kv_indices:
      # [0, 3, 6, 8]
      paged_kv_indices: List[int] = []
      # 0 at the beginning of paged_kv_indptr indicates the start of the
      # first request's page indices in the paged_kv_indices list.
      paged_kv_indptr: List[int] = [0]
      # paged_kv_last_page_len is the length of the last page of each request
      paged_kv_last_page_len: List[int] = []

      for group in seq_group_metadata_list:
          assert not group.is_prompt
          assert group.token_chunk_size == 1

          seq_ids = list(group.seq_data.keys())
          lora_id = group.lora_int_id
          if lora_id > 0:
              lora_requests.add(group.lora_request)

          for seq_id in seq_ids:
              seq_data = group.seq_data[seq_id]
              input_tokens.append(seq_data.get_last_token_id())

              seq_len = seq_data.get_len()
              position = seq_len - 1
              input_positions.append(position)

              # The attended length is capped by the sliding window, if any.
              if self.sliding_window is not None:
                  seq_len = min(seq_len, self.sliding_window)
              seq_lens.append(seq_len)

              block_table = group.block_tables[seq_id]
              block_idx, block_offset = divmod(position, self.block_size)
              slot_mapping.append(block_table[block_idx] * self.block_size +
                                  block_offset)
              lora_index_mapping.append(lora_id)
              lora_prompt_mapping.append(lora_id)

              if self.sliding_window is not None:
                  sliding_window_blocks = (self.sliding_window //
                                           self.block_size)
                  block_table = block_table[-sliding_window_blocks:]
              block_tables.append(block_table)

              paged_kv_indices.extend(block_table)
              paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table))
              last_page_len = seq_data.get_len() % self.block_size
              # A remainder of zero means the last page is completely full.
              paged_kv_last_page_len.append(
                  self.block_size if last_page_len == 0 else last_page_len)

      # Aphrodite uses cuda graph only for decoding requests.
      # See `capture_model` API for more details.
      # For decoding requests, batch_size == input_tokens.
      batch_size = len(input_tokens)
      max_seq_len = max(seq_lens)
      use_captured_graph = (not self.model_config.enforce_eager
                            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
                            and max_seq_len <= self.max_seq_len_to_capture)
      if use_captured_graph:
          graph_batch_size = _get_graph_batch_size(batch_size)
          assert graph_batch_size >= batch_size
          # Pad every per-token list up to the captured batch size.
          num_padding = graph_batch_size - batch_size
          input_tokens.extend([0] * num_padding)
          input_positions.extend([0] * num_padding)
          slot_mapping.extend([_PAD_SLOT_ID] * num_padding)
          seq_lens.extend([1] * num_padding)
          block_tables.extend([] for _ in range(num_padding))
          lora_index_mapping.extend([0] * num_padding)
          batch_size = graph_batch_size

      seq_lens_tensor = torch.tensor(seq_lens,
                                     dtype=torch.int,
                                     device=self.device)

      if use_captured_graph:
          # When using cuda-graph all these tensors should be
          # padded.
          num_rows = seq_lens_tensor.shape[0]
          assert num_rows == len(input_tokens)
          assert num_rows == len(input_positions)
          assert num_rows == len(slot_mapping)

          # The shape of graph_block_tables is
          # [max batch size, max context len // block size].
          input_block_tables = self.graph_block_tables[:batch_size]
          for row, table in enumerate(block_tables):
              if table:
                  input_block_tables[row, :len(table)] = table
          block_tables = torch.tensor(input_block_tables, device=self.device)
      else:
          block_tables = make_tensor_with_pad(
              block_tables,
              max_len=max(len(table) for table in block_tables),
              pad=0,
              dtype=torch.int,
              device=self.device,
          )

      if self.attn_backend.get_name() == "flashinfer":
          if not hasattr(self, "flashinfer_workspace_buffer"):
              # Allocate 16MB workspace buffer
              # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
              self.flashinfer_workspace_buffer = torch.empty(
                  16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
          paged_kv_indptr = torch.tensor(paged_kv_indptr,
                                         dtype=torch.int,
                                         device=self.device)
          paged_kv_indices = torch.tensor(paged_kv_indices,
                                          dtype=torch.int,
                                          device=self.device)
          paged_kv_last_page_len = torch.tensor(paged_kv_last_page_len,
                                                dtype=torch.int,
                                                device=self.device)
          kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
                                                    self.model_config.dtype)
          attn_metadata = self.attn_backend.make_metadata(
              is_prompt=False,
              use_cuda_graph=False,
              workspace_buffer=self.flashinfer_workspace_buffer,
              paged_kv_indptr=paged_kv_indptr,
              paged_kv_indices=paged_kv_indices,
              paged_kv_last_page_len=paged_kv_last_page_len,
              num_qo_heads=self.model_config.get_num_attention_heads(
                  self.parallel_config),
              num_kv_heads=self.model_config.get_num_kv_heads(
                  self.parallel_config),
              head_dim=self.model_config.get_head_size(),
              page_size=self.block_size,
              data_type=kv_cache_dtype)
      else:
          attn_metadata = self.attn_backend.make_metadata(
              is_prompt=False,
              seq_lens=None,
              seq_lens_tensor=seq_lens_tensor,
              max_query_len=None,
              max_seq_len=max_seq_len,
              subquery_start_loc=None,
              seq_start_loc=None,
              context_lens_tensor=None,
              block_tables=block_tables,
              use_cuda_graph=use_captured_graph,
          )

      return PrepareDecodeMetadata(
          input_tokens=input_tokens,
          input_positions=input_positions,
          attn_metadata=attn_metadata,
          lora_index_mapping=lora_index_mapping,
          lora_prompt_mapping=lora_prompt_mapping,
          lora_requests=lora_requests,
          slot_mapping=slot_mapping,
      )
  def prepare_input_tensors(
      self,
      seq_group_metadata_list: List[SequenceGroupMetadata],
  ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
             Set[LoRARequest], LoRAMapping, torch.Tensor]:
      """Build the model inputs on the driver and share them with the workers.

      The driver worker splits the batch into prefill and decode requests,
      prepares both, coalesces the token-level inputs and broadcasts them.
      Every other worker receives the broadcast and rebuilds the attention
      metadata from it.

      Args:
          seq_group_metadata_list: Scheduled sequence groups. Only read when
              this runner is the driver worker.

      Returns:
          A tuple of (input tokens, input positions, attention metadata,
          sampling metadata, LoRA requests, LoRA mapping, multi-modal input).
      """
      if self.is_driver_worker:
          prefill_reqs = []
          decode_reqs = []
          for meta in seq_group_metadata_list:
              bucket = prefill_reqs if meta.is_prompt else decode_reqs
              bucket.append(meta)

          # Prepare input tensors.
          prompt = self._prepare_prompt(prefill_reqs)
          decode = self._prepare_decode(decode_reqs)
          prefill_attn_metadata = prompt.attn_metadata
          decode_attn_metadata = decode.attn_metadata
          multi_modal_input = prompt.multi_modal_input

          sampling_metadata = SamplingMetadata.prepare(
              seq_group_metadata_list, prompt.seq_lens, prompt.query_lens,
              self.device, self.pin_memory)

          if not self.scheduler_config.chunked_prefill_enabled:
              # Without chunked prefill a batch is never mixed.
              assert not (prefill_reqs and decode_reqs)

          num_prefills = len(prompt.seq_lens)
          num_prefill_tokens = len(prompt.input_tokens)
          num_decode_tokens = len(decode.input_tokens)

          # Coalesce tensors. Note that attn_metadata is currently not
          # coalesced for simplicity.
          lora_index_mapping = (prompt.lora_index_mapping +
                                decode.lora_index_mapping)
          lora_prompt_mapping = (prompt.lora_prompt_mapping +
                                 decode.lora_prompt_mapping)
          lora_requests = prompt.lora_requests | decode.lora_requests

          input_tokens = torch.tensor(prompt.input_tokens +
                                      decode.input_tokens,
                                      dtype=torch.long,
                                      device=self.device)
          input_positions = torch.tensor(prompt.input_positions +
                                         decode.input_positions,
                                         dtype=torch.long,
                                         device=self.device)
          slot_mapping = torch.tensor(prompt.slot_mapping +
                                      decode.slot_mapping,
                                      dtype=torch.long,
                                      device=self.device)

          lora_mapping = None
          if self.lora_config:
              lora_mapping = LoRAMapping(
                  lora_index_mapping,
                  lora_prompt_mapping,
              )

          # Broadcast the metadata.
          # If batch contains both prefill and decode, it sends 2 broadcasts.
          # If it only contains 1 type, it triggers a single broadcast.
          if prefill_attn_metadata is None:
              batch_type = BatchType.DECODE
          elif decode_attn_metadata is None:
              batch_type = BatchType.PREFILL
          else:
              batch_type = BatchType.MIXED

          metadata_dict = {
              "input_tokens": input_tokens,
              "input_positions": input_positions,
              "selected_token_indices":
              sampling_metadata.selected_token_indices,
              "lora_requests": lora_requests,
              "lora_mapping": lora_mapping,
              "multi_modal_input": multi_modal_input,
              "num_prefill_tokens": num_prefill_tokens,
              "num_decode_tokens": num_decode_tokens,
              "slot_mapping": slot_mapping,
              "num_prefills": num_prefills,
              "batch_type": batch_type,
          }
          # The first broadcast carries the prefill metadata when there is
          # any, and the decode metadata otherwise.
          first_stage_metadata = prefill_attn_metadata
          if first_stage_metadata is None:
              assert decode_attn_metadata is not None
              first_stage_metadata = decode_attn_metadata
          metadata_dict.update(first_stage_metadata.asdict_zerocopy())
          broadcast_tensor_dict(metadata_dict, src=0)

          # Broadcast decode attn metadata for mixed batch type.
          # The additional broadcast costs 300us overhead on 4 A10 GPUs.
          # We can potentially reduce the overhead by coalescing tensors.
          if batch_type == BatchType.MIXED:
              assert decode_attn_metadata is not None
              metadata_dict = decode_attn_metadata.asdict_zerocopy()
              broadcast_tensor_dict(metadata_dict, src=0)
      else:
          metadata_dict = broadcast_tensor_dict(src=0)
          # Strip the shared fields; what remains is the keyword arguments
          # for the attention backend's make_metadata.
          input_tokens = metadata_dict.pop("input_tokens")
          input_positions = metadata_dict.pop("input_positions")
          slot_mapping = metadata_dict.pop("slot_mapping")
          num_prefills = metadata_dict.pop("num_prefills")
          selected_token_indices = metadata_dict.pop(
              "selected_token_indices")
          lora_mapping = metadata_dict.pop("lora_mapping")
          lora_requests = metadata_dict.pop("lora_requests")
          multi_modal_input = metadata_dict.pop("multi_modal_input")
          num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
          num_decode_tokens = metadata_dict.pop("num_decode_tokens")
          batch_type = metadata_dict.pop("batch_type")

          # Create an attention metadata.
          prefill_attn_metadata = None
          decode_attn_metadata = None
          stage_metadata = self.attn_backend.make_metadata(**metadata_dict)
          if batch_type in (BatchType.PREFILL, BatchType.MIXED):
              prefill_attn_metadata = stage_metadata
          else:
              decode_attn_metadata = stage_metadata

          sampling_metadata = SamplingMetadata(
              seq_groups=None,
              selected_token_indices=selected_token_indices,
              categorized_sample_indices=None,
              num_prompts=0,
          )

          # if it is a mixed batch, decode attn_metadata is broadcasted
          # separately.
          if batch_type == BatchType.MIXED:
              metadata_dict = broadcast_tensor_dict(src=0)
              decode_attn_metadata = self.attn_backend.make_metadata(
                  **metadata_dict)

      attn_metadata = AttentionMetadata(
          num_prefills=num_prefills,
          slot_mapping=slot_mapping,
          num_prefill_tokens=num_prefill_tokens,
          num_decode_tokens=num_decode_tokens,
          prefill_metadata=prefill_attn_metadata,
          decode_metadata=decode_attn_metadata,
          kv_cache_dtype=self.kv_cache_dtype,
      )

      return (input_tokens, input_positions, attn_metadata,
              sampling_metadata, lora_requests, lora_mapping,
              multi_modal_input)
  686. @torch.inference_mode()
  687. def execute_model(
  688. self,
  689. seq_group_metadata_list: List[SequenceGroupMetadata],
  690. kv_caches: List[torch.Tensor],
  691. ) -> Optional[SamplerOutput]:
  692. (input_tokens, input_positions, attn_metadata, sampling_metadata,
  693. lora_requests, lora_mapping, multi_modal_input
  694. ) = self.prepare_input_tensors(seq_group_metadata_list)
  695. if self.lora_config:
  696. self.set_active_loras(lora_requests, lora_mapping)
  697. # Currently cuda graph is only supported by the decode phase.
  698. prefill_meta = attn_metadata.prefill_metadata
  699. decode_meta = attn_metadata.decode_metadata
  700. if prefill_meta is None and decode_meta.use_cuda_graph:
  701. graph_batch_size = input_tokens.shape[0]
  702. model_executable = self.graph_runners[graph_batch_size]
  703. else:
  704. model_executable = self.model
  705. execute_model_kwargs = {
  706. "input_ids": input_tokens,
  707. "positions": input_positions,
  708. "kv_caches": kv_caches,
  709. "attn_metadata": attn_metadata,
  710. }
  711. if self.vision_language_config:
  712. execute_model_kwargs.update({"image_input": multi_modal_input})
  713. hidden_states = model_executable(**execute_model_kwargs)
  714. # Compute the logits.
  715. logits = self.model.compute_logits(hidden_states, sampling_metadata)
  716. # Only perform sampling in the driver worker.
  717. if not self.is_driver_worker:
  718. return None
  719. # Sample the next token.
  720. output = self.model.sample(
  721. logits=logits,
  722. sampling_metadata=sampling_metadata,
  723. )
  724. return output
  725. @torch.inference_mode()
  726. def profile_run(self) -> None:
  727. # Enable top-k sampling to reflect the accurate memory usage.
  728. sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
  729. max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
  730. max_num_seqs = self.scheduler_config.max_num_seqs
  731. # This represents the maximum number of different requests
  732. # that will have unique loras, an therefore the max amount of memory
  733. # consumption create dummy lora request copies from the lora request
  734. # passed in, which contains a lora from the lora warmup path.
  735. dummy_lora_requests = []
  736. dummy_lora_requests_per_seq = []
  737. if self.lora_config:
  738. assert self.lora_manager is not None
  739. with self.lora_manager.dummy_lora_cache():
  740. for idx in range(self.lora_config.max_loras):
  741. lora_id = idx + 1
  742. dummy_lora_request = LoRARequest(
  743. lora_name=f"warmup_{lora_id}",
  744. lora_int_id=lora_id,
  745. lora_local_path="/not/a/real/path",
  746. )
  747. self.lora_manager.add_dummy_lora(dummy_lora_request,
  748. rank=LORA_WARMUP_RANK)
  749. dummy_lora_requests.append(dummy_lora_request)
  750. dummy_lora_requests_per_seq = [
  751. dummy_lora_requests[idx % len(dummy_lora_requests)]
  752. for idx in range(max_num_seqs)
  753. ]
  754. # Profile memory usage with max_num_sequences sequences and the total
  755. # number of tokens equal to max_num_batched_tokens.
  756. seqs: List[SequenceGroupMetadata] = []
  757. # Additional GPU memory may be needed for vision encoding, which needs
  758. # to be accounted for when calculating the GPU blocks for
  759. # Aphrodite blocker manager.
  760. # To exercise the worst scenario for GPU memory consumption,
  761. # the number of seqs (batch_size) is chosen to maximize the number
  762. # of images processed.
  763. if self.vision_language_config:
  764. max_num_seqs = min(
  765. max_num_seqs,
  766. int(max_num_batched_tokens /
  767. self.vision_language_config.image_feature_size))
  768. for group_id in range(max_num_seqs):
  769. seq_len = (max_num_batched_tokens // max_num_seqs +
  770. (group_id < max_num_batched_tokens % max_num_seqs))
  771. seq_data, fake_multi_modal_input = _prepare_fake_inputs(
  772. seq_len, self.vision_language_config)
  773. seq = SequenceGroupMetadata(
  774. request_id=str(group_id),
  775. is_prompt=True,
  776. seq_data={group_id: seq_data},
  777. sampling_params=sampling_params,
  778. block_tables=None,
  779. lora_request=dummy_lora_requests_per_seq[group_id]
  780. if dummy_lora_requests_per_seq else None,
  781. multi_modal_data=fake_multi_modal_input,
  782. )
  783. seqs.append(seq)
  784. # Run the model with the dummy inputs.
  785. num_layers = self.model_config.get_num_layers(self.parallel_config)
  786. kv_caches = [None] * num_layers
  787. self.execute_model(seqs, kv_caches)
  788. torch.cuda.synchronize()
  789. return
  790. def remove_all_loras(self):
  791. if not self.lora_manager:
  792. raise RuntimeError("LoRA is not enabled.")
  793. self.lora_manager.remove_all_loras()
  794. def set_active_loras(self, lora_requests: Set[LoRARequest],
  795. lora_mapping: LoRAMapping) -> None:
  796. if not self.lora_manager:
  797. raise RuntimeError("LoRA is not enabled.")
  798. self.lora_manager.set_active_loras(lora_requests, lora_mapping)
  799. def add_lora(self, lora_request: LoRARequest) -> bool:
  800. if not self.lora_manager:
  801. raise RuntimeError("LoRA is not enabled.")
  802. return self.lora_manager.add_lora(lora_request)
  803. def remove_lora(self, lora_id: int) -> bool:
  804. if not self.lora_manager:
  805. raise RuntimeError("LoRA is not enabled.")
  806. return self.lora_manager.remove_lora(lora_id)
  807. def list_loras(self) -> Set[int]:
  808. if not self.lora_manager:
  809. raise RuntimeError("LoRA is not enabled.")
  810. return self.lora_manager.list_loras()
  811. @torch.inference_mode()
  812. def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
  813. """Cuda graph capture a model.
  814. Note that CUDA graph's performance gain is negligible if number
  815. of batched tokens are larger than 200. And since CUDA graph
  816. requires fixed sized tensors, supporting large/variable batch
  817. size requires high GPU memory overhead. Thus, vLLM only captures
  818. decoding requests. Mixed batch (chunked prefill + decoding) or
  819. prefill requests are not captured.
  820. Since it is used for decoding-only, it assumes there's only 1 token
  821. per sequence in the batch.
  822. """
  823. assert not self.model_config.enforce_eager
  824. logger.info("Capturing the model for CUDA graphs. This may lead to "
  825. "unexpected consequences if the model is not static. To "
  826. "run the model in eager mode, set 'enforce_eager=True' or "
  827. "use '--enforce-eager' in the CLI.")
  828. logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
  829. "If you are running out of memory, consider decreasing "
  830. "`gpu_memory_utilization` or enforcing eager mode. "
  831. "You can also reduce the `max_num_seqs` as needed "
  832. "to decrease memory usage.")
  833. start_time = time.perf_counter()
  834. # Prepare dummy inputs. These will be reused for all batch sizes.
  835. max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
  836. input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
  837. input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
  838. slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
  839. slot_mapping.fill_(_PAD_SLOT_ID)
  840. seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
  841. block_tables = torch.from_numpy(self.graph_block_tables).cuda()
  842. graph_batch_size = _get_graph_batch_size(
  843. self.scheduler_config.max_num_seqs)
  844. batch_size_capture_list = [
  845. bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
  846. ]
  847. with graph_capture():
  848. # NOTE: Capturing the largest batch size first may help reduce the
  849. # memory usage of CUDA graph.
  850. for batch_size in reversed(batch_size_capture_list):
  851. # Create dummy attn_metadata.
  852. decode_metadata = self.attn_backend.make_metadata(
  853. is_prompt=False,
  854. seq_lens=None,
  855. seq_lens_tensor=seq_lens[:batch_size],
  856. max_query_len=None,
  857. max_seq_len=self.max_seq_len_to_capture,
  858. subquery_start_loc=None,
  859. seq_start_loc=None,
  860. context_lens_tensor=None,
  861. block_tables=block_tables[:batch_size],
  862. use_cuda_graph=True,
  863. )
  864. attn_metadata = AttentionMetadata(
  865. num_prefills=0,
  866. num_prefill_tokens=0,
  867. num_decode_tokens=batch_size,
  868. slot_mapping=slot_mapping[:batch_size],
  869. prefill_metadata=None,
  870. decode_metadata=decode_metadata,
  871. kv_cache_dtype=self.kv_cache_dtype,
  872. )
  873. if self.lora_config:
  874. lora_mapping = LoRAMapping(
  875. [0] * batch_size,
  876. [0] * batch_size,
  877. )
  878. self.set_active_loras(set(), lora_mapping)
  879. graph_runner = CUDAGraphRunner(self.model)
  880. graph_runner.capture(
  881. input_tokens[:batch_size],
  882. input_positions[:batch_size],
  883. kv_caches,
  884. attn_metadata,
  885. memory_pool=self.graph_memory_pool,
  886. )
  887. self.graph_memory_pool = graph_runner.graph.pool()
  888. self.graph_runners[batch_size] = graph_runner
  889. end_time = time.perf_counter()
  890. elapsed_time = end_time - start_time
  891. # This usually takes < 10 seconds.
  892. logger.info("Graph capturing finished in %.0f secs.", elapsed_time)
  893. def __del__(self) -> None:
  894. # Delete the CUDA graphs before deleting the pynccl communicator.
  895. # NOTE: This is necessary because otherwise deadlocks can
  896. # happen.
  897. # FIXME: This is a bit hacky. Find a more robust solution.
  898. # TODO: when we get enough user feedback that pynccl is
  899. # more stable than cupy, we can remove this, e.g. in v0.4.1.
  900. self.graph_runners.clear()
  901. self.pynccl_backend = None
  902. @property
  903. def vocab_size(self) -> int:
  904. return self.model_config.get_vocab_size()
  905. class CUDAGraphRunner:
  906. def __init__(self, model: nn.Module):
  907. self.model = model
  908. self.input_buffers: Dict[str, torch.Tensor] = {}
  909. self.output_buffers: Dict[str, torch.Tensor] = {}
  910. self._graph: Optional[torch.cuda.CUDAGraph] = None
  911. @property
  912. def graph(self):
  913. assert self._graph is not None
  914. return self._graph
  915. def capture(
  916. self,
  917. input_ids: torch.Tensor,
  918. positions: torch.Tensor,
  919. kv_caches: List[torch.Tensor],
  920. attn_metadata: AttentionMetadata,
  921. memory_pool,
  922. **kwargs,
  923. ) -> None:
  924. assert self._graph is None
  925. # Run the model once without capturing the graph.
  926. # This is to make sure that the captured graph does not include the
  927. # kernel launches for initial benchmarking (e.g., Triton autotune).
  928. with graph_mode():
  929. self.model(
  930. input_ids,
  931. positions,
  932. kv_caches,
  933. attn_metadata,
  934. **kwargs,
  935. )
  936. torch.cuda.synchronize()
  937. # Capture the graph.
  938. # NOTE: Python 3.8 does not support multi-line with statements.
  939. # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
  940. self._graph = torch.cuda.CUDAGraph()
  941. with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117
  942. with graph_mode():
  943. hidden_states = self.model(
  944. input_ids,
  945. positions,
  946. kv_caches,
  947. attn_metadata,
  948. **kwargs,
  949. )
  950. torch.cuda.synchronize()
  951. # Save the input and output buffers.
  952. self.input_buffers = {
  953. "input_ids": input_ids,
  954. "positions": positions,
  955. "kv_caches": kv_caches,
  956. "slot_mapping": attn_metadata.slot_mapping,
  957. "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
  958. "block_tables": attn_metadata.decode_metadata.block_tables,
  959. }
  960. self.output_buffers = {"hidden_states": hidden_states}
  961. return
  962. def forward(
  963. self,
  964. input_ids: torch.Tensor,
  965. positions: torch.Tensor,
  966. kv_caches: List[torch.Tensor],
  967. attn_metadata: AttentionMetadata,
  968. **kwargs,
  969. ) -> torch.Tensor:
  970. # KV caches are fixed tensors, so we don't need to copy them.
  971. del kv_caches
  972. # Copy the input tensors to the input buffers.
  973. self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
  974. self.input_buffers["positions"].copy_(positions, non_blocking=True)
  975. self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
  976. non_blocking=True)
  977. self.input_buffers["seq_lens_tensor"].copy_(
  978. attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
  979. self.input_buffers["block_tables"].copy_(
  980. attn_metadata.decode_metadata.block_tables, non_blocking=True)
  981. # Run the graph.
  982. self.graph.replay()
  983. # Return the output tensor.
  984. return self.output_buffers["hidden_states"]
  985. def __call__(self, *args, **kwargs):
  986. return self.forward(*args, **kwargs)
  987. def _get_graph_batch_size(batch_size: int) -> int:
  988. """Returns the padded batch size given actual batch size.
  989. Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
  990. 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
  991. """
  992. if batch_size <= 2:
  993. return batch_size
  994. elif batch_size <= 4:
  995. return 4
  996. else:
  997. return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
  998. _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
  999. def _prepare_fake_inputs(
  1000. seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
  1001. """Prepare fake inputs for profile run."""
  1002. if vision_language_config:
  1003. prompt_tokens = [
  1004. vision_language_config.image_token_id
  1005. ] * vision_language_config.image_feature_size + [0] * (
  1006. seq_len - vision_language_config.image_feature_size)
  1007. fake_image_input = MultiModalData(
  1008. type=MultiModalData.Type.IMAGE,
  1009. data=torch.zeros(vision_language_config.image_input_shape,
  1010. dtype=torch.float16))
  1011. else:
  1012. prompt_tokens = [0] * seq_len
  1013. fake_image_input = None
  1014. return SequenceData(prompt_tokens), fake_image_input
  1015. def _is_block_tables_empty(block_tables: Union[None, Dict]):
  1016. """
  1017. Check if block_tables is None or a dictionary with all None values.
  1018. """
  1019. if block_tables is None:
  1020. return True
  1021. if isinstance(block_tables, dict) and all(
  1022. value is None for value in block_tables.values()):
  1023. return True
  1024. return False