cpu_model_runner.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. from collections import defaultdict
  2. from typing import Dict, List, Optional, Tuple
  3. import torch
  4. from torch import nn
  5. from aphrodite.attention import AttentionMetadata, get_attn_backend
  6. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  7. LoRAConfig, ModelConfig, ParallelConfig,
  8. SchedulerConfig, VisionLanguageConfig)
  9. from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
  10. from aphrodite.common.utils import make_tensor_with_pad
  11. from aphrodite.distributed import broadcast_tensor_dict
  12. from aphrodite.modeling import SamplingMetadata
  13. from aphrodite.modeling.model_loader import get_model
  14. from aphrodite.multimodal import MULTIMODAL_REGISTRY
  15. _PAD_SLOT_ID = -1
  16. class CPUModelRunner:
  17. def __init__(
  18. self,
  19. model_config: ModelConfig,
  20. parallel_config: ParallelConfig,
  21. scheduler_config: SchedulerConfig,
  22. device_config: DeviceConfig,
  23. cache_config: CacheConfig,
  24. load_config: LoadConfig,
  25. lora_config: Optional[LoRAConfig],
  26. vision_language_config: Optional[VisionLanguageConfig],
  27. kv_cache_dtype: Optional[str] = "auto",
  28. is_driver_worker: bool = False,
  29. *args,
  30. **kwargs,
  31. ):
  32. self.model_config = model_config
  33. self.parallel_config = parallel_config
  34. self.scheduler_config = scheduler_config
  35. # Currently, CPU worker doesn't support chunked prefill.
  36. assert self.scheduler_config.chunked_prefill_enabled is False
  37. self.device_config = device_config
  38. self.cache_config = cache_config
  39. self.lora_config = lora_config
  40. self.vision_language_config = vision_language_config
  41. self.load_config = load_config
  42. self.is_driver_worker = is_driver_worker
  43. self.device = self.device_config.device
  44. self.kv_cache_dtype = kv_cache_dtype
  45. self.sliding_window = model_config.get_sliding_window()
  46. self.block_size = cache_config.block_size
  47. self.attn_backend = get_attn_backend(
  48. self.model_config.get_num_attention_heads(self.parallel_config),
  49. self.model_config.get_head_size(),
  50. self.model_config.get_num_kv_heads(self.parallel_config),
  51. self.model_config.get_sliding_window(),
  52. self.model_config.dtype,
  53. self.kv_cache_dtype,
  54. self.block_size,
  55. )
  56. # Create processor for multi-modal data
  57. if self.vision_language_config is not None:
  58. self.multi_modal_input_processor = MULTIMODAL_REGISTRY \
  59. .create_input_processor(
  60. self.model_config,
  61. self.vision_language_config,
  62. )
  63. else:
  64. self.multi_modal_input_processor = None
  65. # Lazy initialization.
  66. self.model: nn.Module # Set after init_Model
  67. def load_model(self) -> None:
  68. self.model = get_model(
  69. model_config=self.model_config,
  70. load_config=self.load_config,
  71. device_config=self.device_config,
  72. vision_language_config=self.vision_language_config,
  73. lora_config=self.lora_config,
  74. parallel_config=self.parallel_config,
  75. scheduler_config=self.scheduler_config,
  76. cache_config=self.cache_config)
  77. def _prepare_prompt(
  78. self,
  79. seq_group_metadata_list: List[SequenceGroupMetadata],
  80. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], Dict[
  81. str, torch.Tensor]]:
  82. assert len(seq_group_metadata_list) > 0
  83. input_tokens: List[int] = []
  84. input_positions: List[int] = []
  85. slot_mapping: List[int] = []
  86. seq_lens: List[int] = []
  87. multi_modal_kwargs_list: Dict[str,
  88. List[torch.Tensor]] = defaultdict(list)
  89. for seq_group_metadata in seq_group_metadata_list:
  90. assert seq_group_metadata.is_prompt
  91. seq_ids = list(seq_group_metadata.seq_data.keys())
  92. assert len(seq_ids) == 1
  93. seq_id = seq_ids[0]
  94. seq_data = seq_group_metadata.seq_data[seq_id]
  95. prompt_tokens = seq_data.get_token_ids()
  96. computed_len = seq_data.get_num_computed_tokens()
  97. seq_len = len(prompt_tokens)
  98. seq_lens.append(seq_len) # Prompt token num
  99. input_tokens.extend(prompt_tokens) # Token ids
  100. # Token position ids
  101. # NOTE: Here we assume that the first token in the prompt
  102. # is always the first token in the sequence.
  103. input_positions.extend(list(range(computed_len, seq_len)))
  104. mm_data = seq_group_metadata.multi_modal_data
  105. if mm_data is not None:
  106. # Process multi-modal data
  107. if self.multi_modal_input_processor is None:
  108. raise ValueError(
  109. "Multi-modal inputs are only supported by "
  110. "vision language models.")
  111. mm_kwargs = self.multi_modal_input_processor(mm_data)
  112. for k, v in mm_kwargs.items():
  113. multi_modal_kwargs_list[k].append(v)
  114. # Compute the slot mapping.
  115. block_table = seq_group_metadata.block_tables[seq_id]
  116. # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
  117. # where start_idx is max(0, seq_len - sliding_window).
  118. # For example, if the prompt len is 10, sliding window is 8, and
  119. # block size is 4, the first two tokens are masked and the slot
  120. # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
  121. start_idx = 0
  122. if self.sliding_window is not None:
  123. start_idx = max(0, seq_len - self.sliding_window)
  124. for i in range(computed_len, seq_len):
  125. if i < start_idx:
  126. slot_mapping.append(_PAD_SLOT_ID)
  127. continue
  128. block_number = block_table[i //
  129. self.block_size] # type: ignore
  130. block_offset = i % self.block_size # type: ignore
  131. slot = block_number * self.block_size + block_offset
  132. slot_mapping.append(slot)
  133. multi_modal_kwargs = {
  134. k: torch.cat(v, dim=0).to(self.device)
  135. for k, v in multi_modal_kwargs_list.items()
  136. }
  137. num_prompt_tokens = len(input_tokens)
  138. input_tokens = torch.tensor(input_tokens,
  139. dtype=torch.long,
  140. device=self.device) # type: ignore
  141. input_positions = torch.tensor(input_positions,
  142. dtype=torch.long,
  143. device=self.device) # type: ignore
  144. slot_mapping = torch.tensor(slot_mapping,
  145. dtype=torch.long,
  146. device=self.device) # type: ignore
  147. attn_metadata = self.attn_backend.make_metadata(
  148. is_prompt=True,
  149. seq_lens=seq_lens,
  150. seq_lens_tensor=None,
  151. max_decode_seq_len=None,
  152. num_prefills=len(seq_lens),
  153. num_prefill_tokens=num_prompt_tokens,
  154. num_decode_tokens=0,
  155. block_tables=torch.tensor([]),
  156. slot_mapping=slot_mapping,
  157. )
  158. return (input_tokens, input_positions, attn_metadata, seq_lens,
  159. multi_modal_kwargs)
  160. def _prepare_decode(
  161. self,
  162. seq_group_metadata_list: List[SequenceGroupMetadata],
  163. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
  164. assert len(seq_group_metadata_list) > 0
  165. input_tokens: List[int] = []
  166. input_positions: List[int] = []
  167. slot_mapping: List[int] = []
  168. seq_lens: List[int] = []
  169. block_tables: List[List[int]] = []
  170. for seq_group_metadata in seq_group_metadata_list:
  171. assert not seq_group_metadata.is_prompt
  172. assert seq_group_metadata.token_chunk_size == 1
  173. seq_ids = list(seq_group_metadata.seq_data.keys())
  174. for seq_id in seq_ids:
  175. seq_data = seq_group_metadata.seq_data[seq_id]
  176. generation_token = seq_data.get_last_token_id()
  177. input_tokens.append(generation_token)
  178. seq_len = seq_data.get_len()
  179. position = seq_len - 1
  180. input_positions.append(position)
  181. seq_len = seq_len if self.sliding_window is None else min(
  182. seq_len, self.sliding_window)
  183. seq_lens.append(seq_len)
  184. block_table = seq_group_metadata.block_tables[seq_id]
  185. block_number = block_table[position // self.block_size]
  186. block_offset = position % self.block_size
  187. slot = block_number * self.block_size + block_offset
  188. slot_mapping.append(slot)
  189. if self.sliding_window is not None:
  190. sliding_window_blocks = (self.sliding_window //
  191. self.block_size)
  192. block_table = block_table[-sliding_window_blocks:]
  193. block_tables.append(block_table)
  194. max_decode_seq_len = max(seq_lens)
  195. input_tokens = torch.tensor(input_tokens,
  196. dtype=torch.long,
  197. device=self.device)
  198. input_positions = torch.tensor(input_positions,
  199. dtype=torch.long,
  200. device=self.device)
  201. slot_mapping = torch.tensor(slot_mapping,
  202. dtype=torch.long,
  203. device=self.device)
  204. seq_lens_tensor = torch.tensor(seq_lens,
  205. dtype=torch.int,
  206. device=self.device)
  207. max_block_table_len = max(
  208. len(block_table) for block_table in block_tables)
  209. block_tables = make_tensor_with_pad(
  210. block_tables,
  211. max_len=max_block_table_len,
  212. pad=0,
  213. dtype=torch.int,
  214. device=self.device,
  215. )
  216. attn_metadata = self.attn_backend.make_metadata(
  217. is_prompt=False,
  218. slot_mapping=slot_mapping,
  219. seq_lens=seq_lens,
  220. seq_lens_tensor=seq_lens_tensor,
  221. max_decode_seq_len=max_decode_seq_len,
  222. num_prefill_tokens=0,
  223. num_decode_tokens=len(input_tokens),
  224. num_prefills=0,
  225. block_tables=block_tables,
  226. )
  227. return (
  228. input_tokens,
  229. input_positions,
  230. attn_metadata,
  231. )
  232. def prepare_input_tensors(
  233. self,
  234. seq_group_metadata_list: List[SequenceGroupMetadata],
  235. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
  236. Optional[Dict[str, torch.Tensor]]]:
  237. multi_modal_kwargs = None
  238. if self.is_driver_worker:
  239. # NOTE: We assume that all sequences in the group are all prompts or
  240. # all decodes.
  241. is_prompt = seq_group_metadata_list[0].is_prompt
  242. # Prepare input tensors.
  243. if is_prompt:
  244. (input_tokens, input_positions, attn_metadata, seq_lens,
  245. multi_modal_kwargs
  246. ) = self._prepare_prompt(seq_group_metadata_list)
  247. else:
  248. (input_tokens, input_positions,
  249. attn_metadata) = self._prepare_decode(seq_group_metadata_list)
  250. seq_lens = []
  251. sampling_metadata = SamplingMetadata.prepare(
  252. seq_group_metadata_list,
  253. seq_lens,
  254. # query_lens is not needed if chunked prefill is not
  255. # supported. Since CPU worker doesn't support chunked prefill
  256. # just use seq_lens instead.
  257. seq_lens,
  258. self.device,
  259. pin_memory=False)
  260. # Broadcast the metadata.
  261. metadata_dict = {
  262. "input_tokens": input_tokens,
  263. "input_positions": input_positions,
  264. "selected_token_indices":
  265. sampling_metadata.selected_token_indices,
  266. }
  267. metadata_dict.update(attn_metadata.asdict_zerocopy())
  268. broadcast_tensor_dict(metadata_dict, src=0)
  269. else:
  270. metadata_dict = broadcast_tensor_dict(src=0)
  271. input_tokens = metadata_dict.pop("input_tokens")
  272. input_positions = metadata_dict.pop("input_positions")
  273. selected_token_indices = metadata_dict.pop(
  274. "selected_token_indices")
  275. attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
  276. sampling_metadata = SamplingMetadata(
  277. seq_groups=None,
  278. seq_data=None,
  279. seq_lens=None,
  280. selected_token_indices=selected_token_indices,
  281. categorized_sample_indices=None,
  282. generators=None,
  283. )
  284. return (input_tokens, input_positions, attn_metadata,
  285. sampling_metadata, multi_modal_kwargs)
  286. @torch.inference_mode()
  287. def execute_model(
  288. self,
  289. seq_group_metadata_list: List[SequenceGroupMetadata],
  290. kv_caches: List[torch.Tensor],
  291. ) -> Optional[SamplerOutput]:
  292. (input_tokens, input_positions, attn_metadata, sampling_metadata,
  293. multi_modal_input
  294. ) = self.prepare_input_tensors(seq_group_metadata_list)
  295. model_executable = self.model
  296. execute_model_kwargs = {
  297. "input_ids": input_tokens,
  298. "positions": input_positions,
  299. "kv_caches": kv_caches,
  300. "attn_metadata": attn_metadata,
  301. }
  302. if self.vision_language_config and multi_modal_input is not None:
  303. execute_model_kwargs.update(multi_modal_input)
  304. hidden_states = model_executable(**execute_model_kwargs)
  305. # Compute the logits.
  306. logits = self.model.compute_logits(hidden_states, sampling_metadata)
  307. # Only perform sampling in the driver worker.
  308. if not self.is_driver_worker:
  309. return None
  310. # Sample the next token.
  311. output = self.model.sample(
  312. logits=logits,
  313. sampling_metadata=sampling_metadata,
  314. )
  315. return output