cpu_model_runner.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. from typing import List, Optional, Tuple
  2. import torch
  3. from torch import nn
  4. from aphrodite.attention import AttentionMetadata, get_attn_backend
  5. from aphrodite.common.config import (DeviceConfig, LoadConfig, LoRAConfig,
  6. ModelConfig, ParallelConfig,
  7. SchedulerConfig, VisionLanguageConfig)
  8. from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
  9. from aphrodite.common.utils import make_tensor_with_pad
  10. from aphrodite.distributed import broadcast_tensor_dict
  11. from aphrodite.modeling import SamplingMetadata
  12. from aphrodite.modeling.model_loader import get_model
# Placeholder slot index. CPUModelRunner._prepare_prompt writes it into the
# slot mapping for prompt tokens that fall outside the sliding window, in place
# of a real KV-cache slot.
_PAD_SLOT_ID = -1
  14. class CPUModelRunner:
  15. def __init__(
  16. self,
  17. model_config: ModelConfig,
  18. parallel_config: ParallelConfig,
  19. scheduler_config: SchedulerConfig,
  20. device_config: DeviceConfig,
  21. load_config: LoadConfig,
  22. lora_config: Optional[LoRAConfig],
  23. vision_language_config: Optional[VisionLanguageConfig],
  24. kv_cache_dtype: Optional[str] = "auto",
  25. is_driver_worker: bool = False,
  26. *args,
  27. **kwargs,
  28. ):
  29. self.model_config = model_config
  30. self.parallel_config = parallel_config
  31. self.scheduler_config = scheduler_config
  32. # Currently, CPU worker doesn't support chunked prefill.
  33. assert self.scheduler_config.chunked_prefill_enabled is False
  34. self.lora_config = lora_config
  35. self.vision_language_config = vision_language_config
  36. self.load_config = load_config
  37. self.is_driver_worker = is_driver_worker
  38. # model_config can be None in tests/samplers/test_sampler.py.
  39. # FIXME: This is a hack to make the tests work. Refactor this.
  40. self.sliding_window = (model_config.get_sliding_window()
  41. if model_config is not None else None)
  42. self.device_config = (device_config
  43. if device_config is not None else DeviceConfig())
  44. self.device = self.device_config.device
  45. self.kv_cache_dtype = kv_cache_dtype
  46. self.attn_backend = get_attn_backend(
  47. self.model_config.dtype if model_config is not None else None)
  48. # Lazy initialization.
  49. self.model: nn.Module # Set after init_Model
  50. self.block_size: int # Set after initial profiling.
  51. def load_model(self) -> None:
  52. self.model = get_model(
  53. model_config=self.model_config,
  54. load_config=self.load_config,
  55. device_config=self.device_config,
  56. vision_language_config=self.vision_language_config,
  57. lora_config=self.lora_config,
  58. parallel_config=self.parallel_config,
  59. scheduler_config=self.scheduler_config)
  60. def _prepare_prompt(
  61. self,
  62. seq_group_metadata_list: List[SequenceGroupMetadata],
  63. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
  64. Optional[torch.Tensor]]:
  65. assert len(seq_group_metadata_list) > 0
  66. input_tokens: List[int] = []
  67. input_positions: List[int] = []
  68. slot_mapping: List[int] = []
  69. prompt_lens: List[int] = []
  70. multi_modal_input_list: List[torch.Tensor] = []
  71. for seq_group_metadata in seq_group_metadata_list:
  72. assert seq_group_metadata.is_prompt
  73. seq_ids = list(seq_group_metadata.seq_data.keys())
  74. assert len(seq_ids) == 1
  75. seq_id = seq_ids[0]
  76. seq_data = seq_group_metadata.seq_data[seq_id]
  77. prompt_tokens = seq_data.get_token_ids()
  78. computed_len = seq_data.get_num_computed_tokens()
  79. prompt_len = len(prompt_tokens)
  80. prompt_lens.append(prompt_len) # Prompt token num
  81. input_tokens.extend(prompt_tokens) # Token ids
  82. # Token position ids
  83. # NOTE: Here we assume that the first token in the prompt
  84. # is always the first token in the sequence.
  85. input_positions.extend(list(range(computed_len, prompt_len)))
  86. if seq_group_metadata.multi_modal_data:
  87. multi_modal_input_list.append(
  88. seq_group_metadata.multi_modal_data.data)
  89. # Compute the slot mapping.
  90. block_table = seq_group_metadata.block_tables[seq_id]
  91. # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
  92. # where start_idx is max(0, prompt_len - sliding_window).
  93. # For example, if the prompt len is 10, sliding window is 8, and
  94. # block size is 4, the first two tokens are masked and the slot
  95. # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
  96. start_idx = 0
  97. if self.sliding_window is not None:
  98. start_idx = max(0, prompt_len - self.sliding_window)
  99. for i in range(computed_len, prompt_len):
  100. if i < start_idx:
  101. slot_mapping.append(_PAD_SLOT_ID)
  102. continue
  103. block_number = block_table[i //
  104. self.block_size] # type: ignore
  105. block_offset = i % self.block_size # type: ignore
  106. slot = block_number * self.block_size + block_offset
  107. slot_mapping.append(slot)
  108. if multi_modal_input_list:
  109. assert self.vision_language_config, (
  110. "Multi-modal inputs are only supported by "
  111. "vision language models.")
  112. multi_modal_input = torch.cat(multi_modal_input_list,
  113. dim=0).to(self.device)
  114. else:
  115. multi_modal_input = None
  116. num_prompt_tokens = len(input_tokens)
  117. input_tokens = torch.tensor(input_tokens,
  118. dtype=torch.long,
  119. device=self.device) # type: ignore
  120. input_positions = torch.tensor(input_positions,
  121. dtype=torch.long,
  122. device=self.device) # type: ignore
  123. slot_mapping = torch.tensor(slot_mapping,
  124. dtype=torch.long,
  125. device=self.device) # type: ignore
  126. attn_metadata = self.attn_backend.make_metadata(
  127. is_prompt=True,
  128. prompt_lens=prompt_lens,
  129. num_prefills=len(prompt_lens),
  130. num_prefill_tokens=num_prompt_tokens,
  131. num_decode_tokens=0,
  132. prefill_metadata=None,
  133. decode_metadata=None,
  134. max_context_len=None,
  135. context_lens=None,
  136. block_tables=torch.tensor([]),
  137. slot_mapping=slot_mapping,
  138. kv_cache_dtype=self.kv_cache_dtype,
  139. )
  140. return (input_tokens, input_positions, attn_metadata, prompt_lens,
  141. multi_modal_input)
  142. def _prepare_decode(
  143. self,
  144. seq_group_metadata_list: List[SequenceGroupMetadata],
  145. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
  146. assert len(seq_group_metadata_list) > 0
  147. input_tokens: List[int] = []
  148. input_positions: List[int] = []
  149. slot_mapping: List[int] = []
  150. context_lens: List[int] = []
  151. block_tables: List[List[int]] = []
  152. for seq_group_metadata in seq_group_metadata_list:
  153. assert not seq_group_metadata.is_prompt
  154. assert seq_group_metadata.token_chunk_size == 1
  155. seq_ids = list(seq_group_metadata.seq_data.keys())
  156. for seq_id in seq_ids:
  157. seq_data = seq_group_metadata.seq_data[seq_id]
  158. generation_token = seq_data.get_last_token_id()
  159. input_tokens.append(generation_token)
  160. seq_len = seq_data.get_len()
  161. position = seq_len - 1
  162. input_positions.append(position)
  163. context_len = seq_len if self.sliding_window is None else min(
  164. seq_len, self.sliding_window)
  165. context_lens.append(context_len)
  166. block_table = seq_group_metadata.block_tables[seq_id]
  167. block_number = block_table[position // self.block_size]
  168. block_offset = position % self.block_size
  169. slot = block_number * self.block_size + block_offset
  170. slot_mapping.append(slot)
  171. if self.sliding_window is not None:
  172. sliding_window_blocks = (self.sliding_window //
  173. self.block_size)
  174. block_table = block_table[-sliding_window_blocks:]
  175. block_tables.append(block_table)
  176. max_context_len = max(context_lens)
  177. input_tokens = torch.tensor(input_tokens,
  178. dtype=torch.long,
  179. device=self.device)
  180. input_positions = torch.tensor(input_positions,
  181. dtype=torch.long,
  182. device=self.device)
  183. slot_mapping = torch.tensor(slot_mapping,
  184. dtype=torch.long,
  185. device=self.device)
  186. context_lens = torch.tensor(context_lens,
  187. dtype=torch.int,
  188. device=self.device)
  189. max_block_table_len = max(
  190. len(block_table) for block_table in block_tables)
  191. block_tables = make_tensor_with_pad(
  192. block_tables,
  193. max_len=max_block_table_len,
  194. pad=0,
  195. dtype=torch.int,
  196. device=self.device,
  197. )
  198. attn_metadata = self.attn_backend.make_metadata(
  199. is_prompt=False,
  200. slot_mapping=slot_mapping,
  201. prompt_lens=None,
  202. num_prefill_tokens=0,
  203. num_decode_tokens=len(input_tokens),
  204. max_context_len=max_context_len,
  205. num_prefills=0,
  206. prefill_metadata=None,
  207. decode_metadata=None,
  208. context_lens=context_lens,
  209. block_tables=block_tables,
  210. kv_cache_dtype=self.kv_cache_dtype,
  211. )
  212. return (
  213. input_tokens,
  214. input_positions,
  215. attn_metadata,
  216. )
  217. def prepare_input_tensors(
  218. self,
  219. seq_group_metadata_list: List[SequenceGroupMetadata],
  220. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
  221. Optional[torch.Tensor]]:
  222. multi_modal_input = None
  223. if self.is_driver_worker:
  224. # NOTE: We assume that all sequences in the group are all prompts or
  225. # all decodes.
  226. is_prompt = seq_group_metadata_list[0].is_prompt
  227. # Prepare input tensors.
  228. if is_prompt:
  229. (input_tokens, input_positions, attn_metadata, prompt_lens,
  230. multi_modal_input
  231. ) = self._prepare_prompt(seq_group_metadata_list)
  232. else:
  233. (input_tokens, input_positions,
  234. attn_metadata) = self._prepare_decode(seq_group_metadata_list)
  235. prompt_lens = []
  236. sampling_metadata = SamplingMetadata.prepare(
  237. seq_group_metadata_list,
  238. prompt_lens,
  239. # subquery_lens is not needed if chunked prefill is not
  240. # supported. Since CPU worker doesn't support chunked prefill
  241. # just use prompt_lens instead.
  242. prompt_lens,
  243. self.device,
  244. pin_memory=False)
  245. # Broadcast the metadata.
  246. metadata_dict = {
  247. "input_tokens": input_tokens,
  248. "input_positions": input_positions,
  249. "selected_token_indices":
  250. sampling_metadata.selected_token_indices,
  251. }
  252. metadata_dict.update(attn_metadata.asdict_zerocopy())
  253. broadcast_tensor_dict(metadata_dict, src=0)
  254. else:
  255. metadata_dict = broadcast_tensor_dict(src=0)
  256. input_tokens = metadata_dict.pop("input_tokens")
  257. input_positions = metadata_dict.pop("input_positions")
  258. selected_token_indices = metadata_dict.pop(
  259. "selected_token_indices")
  260. attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
  261. sampling_metadata = SamplingMetadata(
  262. seq_groups=None,
  263. seq_data=None,
  264. prompt_lens=None,
  265. selected_token_indices=selected_token_indices,
  266. categorized_sample_indices=None,
  267. generators=None,
  268. )
  269. return (input_tokens, input_positions, attn_metadata,
  270. sampling_metadata, multi_modal_input)
  271. @torch.inference_mode()
  272. def execute_model(
  273. self,
  274. seq_group_metadata_list: List[SequenceGroupMetadata],
  275. kv_caches: List[torch.Tensor],
  276. ) -> Optional[SamplerOutput]:
  277. (input_tokens, input_positions, attn_metadata, sampling_metadata,
  278. multi_modal_input
  279. ) = self.prepare_input_tensors(seq_group_metadata_list)
  280. model_executable = self.model
  281. execute_model_kwargs = {
  282. "input_ids": input_tokens,
  283. "positions": input_positions,
  284. "kv_caches": kv_caches,
  285. "attn_metadata": attn_metadata,
  286. }
  287. if self.vision_language_config:
  288. execute_model_kwargs.update({"image_input": multi_modal_input})
  289. hidden_states = model_executable(**execute_model_kwargs)
  290. # Compute the logits.
  291. logits = self.model.compute_logits(hidden_states, sampling_metadata)
  292. # Only perform sampling in the driver worker.
  293. if not self.is_driver_worker:
  294. return None
  295. # Sample the next token.
  296. output = self.model.sample(
  297. logits=logits,
  298. sampling_metadata=sampling_metadata,
  299. )
  300. return output