# xpu_model_runner.py (17 KB)
  1. from typing import List, Optional, Tuple
  2. import torch
  3. import torch.nn as nn
  4. from loguru import logger
  5. from aphrodite.attention import get_attn_backend
  6. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  7. LoRAConfig, ModelConfig, ParallelConfig,
  8. SchedulerConfig, VisionLanguageConfig)
  9. from aphrodite.common.sampling_params import SamplingParams
  10. from aphrodite.common.sequence import (SamplerOutput, SequenceData,
  11. SequenceGroupMetadata)
  12. from aphrodite.common.utils import CudaMemoryProfiler, make_tensor_with_pad
  13. from aphrodite.distributed import broadcast_tensor_dict
  14. from aphrodite.modeling.model_loader import get_model
  15. from aphrodite.task_handler.model_runner import (AttentionMetadata,
  16. SamplingMetadata)
  17. _PAD_SLOT_ID = -1
  18. _BATCH_SIZE_ALIGNMENT = 8
  19. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
  20. _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
  21. ]
  class XPUModelRunner:
      """Builds model inputs and runs the model on an XPU device.

      The driver worker (``is_driver_worker=True``) turns the scheduler's
      ``SequenceGroupMetadata`` list into input tensors and broadcasts them
      from rank 0; other workers receive those inputs from the broadcast.
      Every worker runs the forward pass and computes logits, but only the
      driver samples the next tokens.
      """

      def __init__(
          self,
          model_config: ModelConfig,
          parallel_config: ParallelConfig,
          scheduler_config: SchedulerConfig,
          device_config: DeviceConfig,
          cache_config: CacheConfig,
          load_config: LoadConfig,
          lora_config: Optional[LoRAConfig],
          vision_language_config: Optional[VisionLanguageConfig],
          kv_cache_dtype: Optional[str] = "auto",
          is_driver_worker: bool = False,
          *args,
          **kwargs,
      ):
          """Store the configs and select the attention backend.

          Extra ``*args`` / ``**kwargs`` are accepted and ignored (presumably
          so callers can pass the same arguments as to other model runners —
          TODO confirm). The model itself is created later by load_model().
          """
          self.model_config = model_config
          self.parallel_config = parallel_config
          self.scheduler_config = scheduler_config
          self.lora_config = lora_config
          self.load_config = load_config
          self.cache_config = cache_config
          self.vision_language_config = vision_language_config
          self.is_driver_worker = is_driver_worker
          self.sliding_window = model_config.get_sliding_window()
          self.device_config = device_config
          self.device = self.device_config.device
          self.kv_cache_dtype = kv_cache_dtype
          self.block_size = cache_config.block_size
          # model_config is dereferenced above, so it is never None here.
          self.max_context_len_to_capture = (
              self.model_config.max_context_len_to_capture)

          self.attn_backend = get_attn_backend(
              self.model_config.get_num_attention_heads(self.parallel_config),
              self.model_config.get_head_size(),
              self.model_config.get_num_kv_heads(self.parallel_config),
              self.sliding_window,
              self.model_config.dtype,
              self.kv_cache_dtype,
              self.block_size,
          )

          # Lazy initialization.
          self.model: nn.Module  # Set by load_model().

      def load_model(self) -> None:
          """Load the model weights and log how much memory they took."""
          # NOTE(review): CudaMemoryProfiler is used on an XPU device; confirm
          # it reports XPU memory rather than CUDA memory.
          with CudaMemoryProfiler() as m:
              self.model = get_model(
                  model_config=self.model_config,
                  device_config=self.device_config,
                  load_config=self.load_config,
                  lora_config=self.lora_config,
                  vision_language_config=self.vision_language_config,
                  parallel_config=self.parallel_config,
                  scheduler_config=self.scheduler_config,
                  cache_config=self.cache_config,
              )
          self.model_memory_usage = m.consumed_memory
          logger.info("Loading model weights took "
                      f"{self.model_memory_usage / float(2**30):.4f} GB")

      @property
      def vocab_size(self) -> int:
          """Vocabulary size reported by the model config."""
          return self.model_config.get_vocab_size()

      @torch.inference_mode()
      def profile_run(self) -> None:
          """Run one forward pass over a worst-case dummy prefill batch.

          Builds ``max_num_seqs`` prompt sequence groups whose lengths sum to
          ``max_num_batched_tokens``, executes them without a KV cache, and
          synchronizes the XPU device afterwards.
          """
          # Enable top-k sampling to reflect the accurate memory usage.
          sampling_params = SamplingParams(top_p=0.99,
                                           top_k=self.vocab_size - 1)
          max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
          max_num_seqs = self.scheduler_config.max_num_seqs

          # Profile memory usage with max_num_seqs sequences and a total
          # number of tokens equal to max_num_batched_tokens. Additional
          # memory may be needed for vision encoding, which must be accounted
          # for when computing the number of KV-cache blocks; to exercise the
          # worst case, the batch size is chosen to maximize the number of
          # images processed.
          seqs: List[SequenceGroupMetadata] = []
          base_len, remainder = divmod(max_num_batched_tokens, max_num_seqs)
          for group_id in range(max_num_seqs):
              # The first `remainder` groups get one extra token so the
              # lengths add up exactly to max_num_batched_tokens.
              seq_len = base_len + (group_id < remainder)
              seq_data = SequenceData([0] * seq_len)
              seqs.append(
                  SequenceGroupMetadata(
                      request_id=str(group_id),
                      is_prompt=True,
                      seq_data={group_id: seq_data},
                      sampling_params=sampling_params,
                      block_tables=None,
                      lora_request=None,
                      multi_modal_data=None,
                  ))

          # Run the model with the dummy inputs. No KV cache exists yet, so
          # every layer gets None (and _prepare_prompt uses pad slots).
          num_layers = self.model_config.get_num_layers(self.parallel_config)
          kv_caches = [None] * num_layers
          self.execute_model(seqs, kv_caches)
          torch.xpu.synchronize()

      def prepare_input_tensors(
          self,
          seq_group_metadata_list: List[SequenceGroupMetadata],
      ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
                 SamplingMetadata, Optional[torch.Tensor]]:
          """Build (driver) or receive (non-driver) the model inputs.

          On the driver worker the inputs are built from
          ``seq_group_metadata_list`` and broadcast from rank 0. On other
          workers the argument is not read; the inputs come from that
          broadcast instead.

          Returns:
              ``(input_tokens, input_positions, attn_metadata,
              sampling_metadata, multi_modal_input)``; ``multi_modal_input``
              is None for decode batches and text-only prompts.
          """
          if self.is_driver_worker:
              # NOTE: We assume that all sequences in the group are all
              # prompts or all decodes.
              is_prompt = seq_group_metadata_list[0].is_prompt
              if is_prompt:
                  (input_tokens, input_positions, attn_metadata, seq_lens,
                   multi_modal_input
                   ) = self._prepare_prompt(seq_group_metadata_list)
              else:
                  (input_tokens, input_positions,
                   attn_metadata) = self._prepare_decode(seq_group_metadata_list)
                  seq_lens = []
                  multi_modal_input = None
              sampling_metadata = SamplingMetadata.prepare(
                  seq_group_metadata_list,
                  seq_lens,
                  # subquery_lens is only needed for chunked prefill. This
                  # runner processes whole prompts (see _prepare_prompt), so
                  # seq_lens is passed in its place.
                  seq_lens,
                  self.device,
                  pin_memory=False)
              # Broadcast the metadata. Non-driver workers also run the
              # forward pass, so they need the multi-modal input as well.
              metadata_dict = {
                  "input_tokens": input_tokens,
                  "input_positions": input_positions,
                  "selected_token_indices":
                  sampling_metadata.selected_token_indices,
                  "multi_modal_input": multi_modal_input,
              }
              metadata_dict.update(attn_metadata.asdict_zerocopy())
              broadcast_tensor_dict(metadata_dict, src=0)
          else:
              metadata_dict = broadcast_tensor_dict(src=0)
              input_tokens = metadata_dict.pop("input_tokens")
              input_positions = metadata_dict.pop("input_positions")
              selected_token_indices = metadata_dict.pop(
                  "selected_token_indices")
              multi_modal_input = metadata_dict.pop("multi_modal_input", None)
              # Whatever remains are the attention-metadata fields.
              attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
              sampling_metadata = SamplingMetadata(
                  seq_groups=None,
                  selected_token_indices=selected_token_indices,
                  categorized_sample_indices=None,
                  num_prompts=0,
              )
          return (input_tokens, input_positions, attn_metadata,
                  sampling_metadata, multi_modal_input)

      def _prepare_decode(
          self,
          seq_group_metadata_list: List[SequenceGroupMetadata],
      ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
          """Build inputs for a decode batch: one input token per sequence.

          Raises:
              AssertionError: if the list is empty, contains a prompt group,
              or a group's ``token_chunk_size`` is not 1.
          """
          assert len(seq_group_metadata_list) > 0
          input_tokens: List[int] = []
          input_positions: List[int] = []
          slot_mapping: List[int] = []
          seq_lens: List[int] = []
          block_tables: List[List[int]] = []

          for seq_group_metadata in seq_group_metadata_list:
              assert not seq_group_metadata.is_prompt
              assert seq_group_metadata.token_chunk_size == 1
              for seq_id, seq_data in seq_group_metadata.seq_data.items():
                  # The only input is the last token; its position is len - 1.
                  input_tokens.append(seq_data.get_last_token_id())
                  seq_len = seq_data.get_len()
                  position = seq_len - 1
                  input_positions.append(position)

                  # With a sliding window, the attended length is capped.
                  if self.sliding_window is not None:
                      seq_len = min(seq_len, self.sliding_window)
                  seq_lens.append(seq_len)

                  # Flat slot index = physical block * block_size + offset.
                  block_table = seq_group_metadata.block_tables[seq_id]
                  block_number = block_table[position // self.block_size]
                  block_offset = position % self.block_size
                  slot_mapping.append(block_number * self.block_size +
                                      block_offset)

                  # Keep only the trailing blocks covered by the window.
                  if self.sliding_window is not None:
                      sliding_window_blocks = (self.sliding_window //
                                               self.block_size)
                      block_table = block_table[-sliding_window_blocks:]
                  block_tables.append(block_table)

          max_decode_seq_len = max(seq_lens)
          num_decode_tokens = len(input_tokens)

          input_tokens_t = torch.tensor(input_tokens,
                                        dtype=torch.long,
                                        device=self.device)
          input_positions_t = torch.tensor(input_positions,
                                           dtype=torch.long,
                                           device=self.device)
          slot_mapping_t = torch.tensor(slot_mapping,
                                        dtype=torch.long,
                                        device=self.device)
          seq_lens_tensor = torch.tensor(seq_lens,
                                         dtype=torch.int,
                                         device=self.device)
          # Ragged block tables are right-padded with 0 to a common length.
          max_block_table_len = max(len(table) for table in block_tables)
          block_tables_t = make_tensor_with_pad(
              block_tables,
              max_len=max_block_table_len,
              pad=0,
              dtype=torch.int,
              device=self.device,
          )

          attn_metadata = self.attn_backend.make_metadata(
              is_prompt=False,
              slot_mapping=slot_mapping_t,
              seq_lens=seq_lens,
              seqlen_q=None,
              max_seqlen=None,
              seq_lens_tensor=seq_lens_tensor,
              max_decode_seq_len=max_decode_seq_len,
              num_prefill_tokens=0,
              num_decode_tokens=num_decode_tokens,
              num_prefills=0,
              block_tables=block_tables_t,
          )
          return (
              input_tokens_t,
              input_positions_t,
              attn_metadata,
          )

      @torch.inference_mode()
      def execute_model(
          self,
          seq_group_metadata_list: List[SequenceGroupMetadata],
          kv_caches: List[torch.Tensor],
      ) -> Optional[SamplerOutput]:
          """Run one model step and, on the driver worker, sample tokens.

          Returns:
              The sampler output on the driver worker; None on other workers.
          """
          (input_tokens, input_positions, attn_metadata, sampling_metadata,
           multi_modal_input
           ) = self.prepare_input_tensors(seq_group_metadata_list)

          execute_model_kwargs = {
              "input_ids": input_tokens,
              "positions": input_positions,
              "kv_caches": kv_caches,
              "attn_metadata": attn_metadata,
          }
          if self.vision_language_config:
              execute_model_kwargs["image_input"] = multi_modal_input
          hidden_states = self.model(**execute_model_kwargs)

          # Compute the logits.
          logits = self.model.compute_logits(hidden_states, sampling_metadata)

          # Only perform sampling in the driver worker.
          if not self.is_driver_worker:
              return None

          # Sample the next token.
          return self.model.sample(
              logits=logits,
              sampling_metadata=sampling_metadata,
          )

      def _prepare_prompt(
          self,
          seq_group_metadata_list: List[SequenceGroupMetadata],
      ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
                 Optional[torch.Tensor]]:
          """Build inputs for a prefill batch (one sequence per group).

          Returns:
              ``(input_tokens, input_positions, attn_metadata, seq_lens,
              multi_modal_input)``; ``multi_modal_input`` is the
              concatenation (dim 0) of every group's multi-modal data, or
              None if no group has any.

          Raises:
              AssertionError: if the list is empty, a group is not a prompt,
              a group holds more than one sequence, or multi-modal data is
              present without a vision-language config.
          """
          assert len(seq_group_metadata_list) > 0
          input_tokens: List[int] = []
          input_positions: List[int] = []
          slot_mapping: List[int] = []
          seq_lens: List[int] = []
          multi_modal_input_list: List[torch.Tensor] = []

          for seq_group_metadata in seq_group_metadata_list:
              assert seq_group_metadata.is_prompt
              seq_ids = list(seq_group_metadata.seq_data.keys())
              assert len(seq_ids) == 1
              seq_id = seq_ids[0]

              seq_data = seq_group_metadata.seq_data[seq_id]
              prompt_tokens = seq_data.get_token_ids()
              computed_len = seq_data.get_num_computed_tokens()
              seq_len = len(prompt_tokens)

              seq_lens.append(seq_len)  # Prompt token num
              input_tokens.extend(prompt_tokens)  # Token ids

              # Token position ids
              # NOTE: Here we assume that the first token in the prompt
              # is always the first token in the sequence.
              # NOTE(review): all prompt tokens are added but positions (and
              # slot mapping below) start at computed_len; these lengths only
              # agree when computed_len == 0 — confirm it always is here.
              input_positions.extend(range(computed_len, seq_len))

              if seq_group_metadata.multi_modal_data:
                  multi_modal_input_list.append(
                      seq_group_metadata.multi_modal_data.data)

              if seq_group_metadata.block_tables is None:
                  # During memory profiling, the block tables are not
                  # initialized yet. In this case, we just use a dummy slot
                  # mapping.
                  slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
                  continue

              # Compute the slot mapping.
              block_table = seq_group_metadata.block_tables[seq_id]
              # Mask the [0, start_idx) tokens of the prompt with
              # _PAD_SLOT_ID, where start_idx is max(0, seq_len -
              # sliding_window). For example, if the prompt len is 10,
              # sliding window is 8, and block size is 4, the first two
              # tokens are masked and the slot mapping will be
              # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
              start_idx = 0
              if self.sliding_window is not None:
                  start_idx = max(0, seq_len - self.sliding_window)
              for i in range(computed_len, seq_len):
                  if i < start_idx:
                      slot_mapping.append(_PAD_SLOT_ID)
                      continue
                  block_number = block_table[i // self.block_size]
                  block_offset = i % self.block_size
                  slot_mapping.append(block_number * self.block_size +
                                      block_offset)

          if multi_modal_input_list:
              assert self.vision_language_config, (
                  "Multi-modal inputs are only supported by "
                  "vision language models.")
              multi_modal_input = torch.cat(multi_modal_input_list,
                                            dim=0).to(self.device)
          else:
              multi_modal_input = None

          num_prompt_tokens = len(input_tokens)
          input_tokens_t = torch.tensor(input_tokens,
                                        dtype=torch.long,
                                        device=self.device)
          input_positions_t = torch.tensor(input_positions,
                                           dtype=torch.long,
                                           device=self.device)
          slot_mapping_t = torch.tensor(slot_mapping,
                                        dtype=torch.long,
                                        device=self.device)

          max_seqlen = max(seq_lens)
          # Cumulative start offsets of each prompt: [0, l0, l0 + l1, ...].
          seqlen_q = torch.cumsum(torch.tensor([0] + seq_lens),
                                  dim=0).to(device=self.device)

          attn_metadata = self.attn_backend.make_metadata(
              is_prompt=True,
              slot_mapping=slot_mapping_t,
              seq_lens=seq_lens,
              seqlen_q=seqlen_q,
              max_seqlen=max_seqlen,
              seq_lens_tensor=None,
              max_decode_seq_len=None,
              num_prefills=len(seq_lens),
              num_prefill_tokens=num_prompt_tokens,
              num_decode_tokens=0,
              block_tables=torch.tensor([], device=self.device,
                                        dtype=torch.int),
          )
          return (input_tokens_t, input_positions_t, attn_metadata, seq_lens,
                  multi_modal_input)
  365. multi_modal_input)