cpu_model_runner.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. from dataclasses import dataclass
  2. from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
  3. Type, Union)
  4. import torch
  5. from torch import nn
  6. from aphrodite.attention import AttentionMetadata, get_attn_backend
  7. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  8. LoRAConfig, ModelConfig, MultiModalConfig,
  9. ParallelConfig, PromptAdapterConfig,
  10. SchedulerConfig)
  11. from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
  12. SequenceGroupMetadata)
  13. from aphrodite.common.utils import make_tensor_with_pad
  14. from aphrodite.modeling import SamplingMetadata
  15. from aphrodite.modeling.model_loader import get_model
  16. from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
  17. MultiModalInputs)
  18. from aphrodite.task_handler.model_runner_base import (
  19. ModelRunnerBase, ModelRunnerInputBase,
  20. _add_attn_metadata_broadcastable_dict,
  21. _add_sampling_metadata_broadcastable_dict,
  22. _init_attn_metadata_from_tensor_dict,
  23. _init_sampling_metadata_from_tensor_dict)
  24. if TYPE_CHECKING:
  25. from aphrodite.attention.backends.abstract import AttentionBackend
# Sentinel slot index. CPUModelRunner._prepare_prompt writes it into
# slot_mapping, in place of a block-table-derived slot, for prompt tokens that
# lie before the sliding-window start (start_idx).
# NOTE(review): presumably the attention backend skips KV-cache writes for
# negative slots — confirm against the CPU attention backend.
_PAD_SLOT_ID = -1
  27. @dataclass(frozen=True)
  28. class CPUModelInput(ModelRunnerInputBase):
  29. """
  30. Used by the CPUModelRunner.
  31. """
  32. input_tokens: Optional[torch.Tensor] = None
  33. input_positions: Optional[torch.Tensor] = None
  34. attn_metadata: Optional["AttentionMetadata"] = None
  35. sampling_metadata: Optional["SamplingMetadata"] = None
  36. multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
  37. def as_broadcastable_tensor_dict(
  38. self) -> Dict[str, Union[int, torch.Tensor]]:
  39. tensor_dict = {
  40. "input_tokens": self.input_tokens,
  41. "input_positions": self.input_positions,
  42. "multi_modal_kwargs": self.multi_modal_kwargs,
  43. }
  44. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  45. _add_sampling_metadata_broadcastable_dict(tensor_dict,
  46. self.sampling_metadata)
  47. return tensor_dict
  48. @classmethod
  49. def from_broadcasted_tensor_dict(
  50. cls: Type["CPUModelInput"],
  51. tensor_dict: Dict[str, Any],
  52. attn_backend: Optional["AttentionBackend"] = None
  53. ) -> "CPUModelInput":
  54. tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
  55. if attn_backend is not None:
  56. tensor_dict = _init_attn_metadata_from_tensor_dict(
  57. attn_backend, tensor_dict)
  58. return cls(**tensor_dict)
  59. class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
  60. def __init__(
  61. self,
  62. model_config: ModelConfig,
  63. parallel_config: ParallelConfig,
  64. scheduler_config: SchedulerConfig,
  65. device_config: DeviceConfig,
  66. cache_config: CacheConfig,
  67. load_config: LoadConfig,
  68. lora_config: Optional[LoRAConfig],
  69. multimodal_config: Optional[MultiModalConfig],
  70. kv_cache_dtype: Optional[str] = "auto",
  71. prompt_adapter_config: Optional[PromptAdapterConfig] = None,
  72. is_driver_worker: bool = False,
  73. *args,
  74. **kwargs,
  75. ):
  76. self.model_config = model_config
  77. self.parallel_config = parallel_config
  78. self.scheduler_config = scheduler_config
  79. # Currently, CPU worker doesn't support chunked prefill.
  80. assert self.scheduler_config.chunked_prefill_enabled is False
  81. self.device_config = device_config
  82. self.cache_config = cache_config
  83. self.lora_config = lora_config
  84. self.multimodal_config = multimodal_config
  85. self.prompt_adapter_config = prompt_adapter_config
  86. self.load_config = load_config
  87. self.is_driver_worker = is_driver_worker
  88. self.device = self.device_config.device
  89. self.kv_cache_dtype = kv_cache_dtype
  90. self.sliding_window = model_config.get_sliding_window()
  91. self.block_size = cache_config.block_size
  92. self.attn_backend = get_attn_backend(
  93. self.model_config.get_num_attention_heads(self.parallel_config),
  94. self.model_config.get_head_size(),
  95. self.model_config.get_num_kv_heads(self.parallel_config),
  96. self.model_config.get_sliding_window(),
  97. self.model_config.dtype,
  98. self.kv_cache_dtype,
  99. self.block_size,
  100. )
  101. # Multi-modal data support
  102. self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
  103. .create_input_mapper(self.model_config)
  104. # Lazy initialization.
  105. self.model: nn.Module # Set after init_Model
  106. def load_model(self) -> None:
  107. self.model = get_model(model_config=self.model_config,
  108. load_config=self.load_config,
  109. device_config=self.device_config,
  110. multimodal_config=self.multimodal_config,
  111. lora_config=self.lora_config,
  112. parallel_config=self.parallel_config,
  113. scheduler_config=self.scheduler_config,
  114. cache_config=self.cache_config)
  115. def _prepare_prompt(
  116. self,
  117. seq_group_metadata_list: List[SequenceGroupMetadata],
  118. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
  119. Mapping[str, BatchedTensors]]:
  120. assert len(seq_group_metadata_list) > 0
  121. input_tokens: List[int] = []
  122. input_positions: List[int] = []
  123. slot_mapping: List[int] = []
  124. seq_lens: List[int] = []
  125. multi_modal_inputs_list: List[MultiModalInputs] = []
  126. for seq_group_metadata in seq_group_metadata_list:
  127. assert seq_group_metadata.is_prompt
  128. seq_ids = list(seq_group_metadata.seq_data.keys())
  129. assert len(seq_ids) == 1
  130. seq_id = seq_ids[0]
  131. seq_data = seq_group_metadata.seq_data[seq_id]
  132. prompt_tokens = seq_data.get_token_ids()
  133. computed_len = seq_data.get_num_computed_tokens()
  134. seq_len = len(prompt_tokens)
  135. seq_lens.append(seq_len) # Prompt token num
  136. input_tokens.extend(prompt_tokens) # Token ids
  137. # Token position ids
  138. # NOTE: Here we assume that the first token in the prompt
  139. # is always the first token in the sequence.
  140. input_positions.extend(list(range(computed_len, seq_len)))
  141. mm_data = seq_group_metadata.multi_modal_data
  142. if mm_data:
  143. mm_kwargs = self.multi_modal_input_mapper(mm_data)
  144. multi_modal_inputs_list.append(mm_kwargs)
  145. # Compute the slot mapping.
  146. block_table = seq_group_metadata.block_tables[seq_id]
  147. # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
  148. # where start_idx is max(0, seq_len - sliding_window).
  149. # For example, if the prompt len is 10, sliding window is 8, and
  150. # block size is 4, the first two tokens are masked and the slot
  151. # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
  152. start_idx = 0
  153. if self.sliding_window is not None:
  154. start_idx = max(0, seq_len - self.sliding_window)
  155. for i in range(computed_len, seq_len):
  156. if i < start_idx:
  157. slot_mapping.append(_PAD_SLOT_ID)
  158. continue
  159. block_number = block_table[i //
  160. self.block_size] # type: ignore
  161. block_offset = i % self.block_size # type: ignore
  162. slot = block_number * self.block_size + block_offset
  163. slot_mapping.append(slot)
  164. num_prompt_tokens = len(input_tokens)
  165. input_tokens = torch.tensor(input_tokens,
  166. dtype=torch.long,
  167. device=self.device) # type: ignore
  168. input_positions = torch.tensor(input_positions,
  169. dtype=torch.long,
  170. device=self.device) # type: ignore
  171. slot_mapping = torch.tensor(slot_mapping,
  172. dtype=torch.long,
  173. device=self.device) # type: ignore
  174. attn_metadata = self.attn_backend.make_metadata(
  175. is_prompt=True,
  176. seq_lens=seq_lens,
  177. seq_lens_tensor=None,
  178. max_decode_seq_len=None,
  179. num_prefills=len(seq_lens),
  180. num_prefill_tokens=num_prompt_tokens,
  181. num_decode_tokens=0,
  182. block_tables=torch.tensor([]),
  183. slot_mapping=slot_mapping,
  184. )
  185. multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
  186. device=self.device)
  187. return (input_tokens, input_positions, attn_metadata, seq_lens,
  188. multi_modal_kwargs)
  189. def _prepare_decode(
  190. self,
  191. seq_group_metadata_list: List[SequenceGroupMetadata],
  192. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
  193. assert len(seq_group_metadata_list) > 0
  194. input_tokens: List[int] = []
  195. input_positions: List[int] = []
  196. slot_mapping: List[int] = []
  197. seq_lens: List[int] = []
  198. block_tables: List[List[int]] = []
  199. for seq_group_metadata in seq_group_metadata_list:
  200. assert not seq_group_metadata.is_prompt
  201. assert seq_group_metadata.token_chunk_size == 1
  202. seq_ids = list(seq_group_metadata.seq_data.keys())
  203. for seq_id in seq_ids:
  204. seq_data = seq_group_metadata.seq_data[seq_id]
  205. generation_token = seq_data.get_last_token_id()
  206. input_tokens.append(generation_token)
  207. seq_len = seq_data.get_len()
  208. position = seq_len - 1
  209. input_positions.append(position)
  210. seq_len = seq_len if self.sliding_window is None else min(
  211. seq_len, self.sliding_window)
  212. seq_lens.append(seq_len)
  213. block_table = seq_group_metadata.block_tables[seq_id]
  214. block_number = block_table[position // self.block_size]
  215. block_offset = position % self.block_size
  216. slot = block_number * self.block_size + block_offset
  217. slot_mapping.append(slot)
  218. if self.sliding_window is not None:
  219. sliding_window_blocks = (self.sliding_window //
  220. self.block_size)
  221. block_table = block_table[-sliding_window_blocks:]
  222. block_tables.append(block_table)
  223. max_decode_seq_len = max(seq_lens)
  224. input_tokens = torch.tensor(input_tokens,
  225. dtype=torch.long,
  226. device=self.device)
  227. input_positions = torch.tensor(input_positions,
  228. dtype=torch.long,
  229. device=self.device)
  230. slot_mapping = torch.tensor(slot_mapping,
  231. dtype=torch.long,
  232. device=self.device)
  233. seq_lens_tensor = torch.tensor(seq_lens,
  234. dtype=torch.int,
  235. device=self.device)
  236. max_block_table_len = max(
  237. len(block_table) for block_table in block_tables)
  238. block_tables = make_tensor_with_pad(
  239. block_tables,
  240. max_len=max_block_table_len,
  241. pad=0,
  242. dtype=torch.int,
  243. device=self.device,
  244. )
  245. attn_metadata = self.attn_backend.make_metadata(
  246. is_prompt=False,
  247. slot_mapping=slot_mapping,
  248. seq_lens=seq_lens,
  249. seq_lens_tensor=seq_lens_tensor,
  250. max_decode_seq_len=max_decode_seq_len,
  251. num_prefill_tokens=0,
  252. num_decode_tokens=len(input_tokens),
  253. num_prefills=0,
  254. block_tables=block_tables,
  255. )
  256. return (
  257. input_tokens,
  258. input_positions,
  259. attn_metadata,
  260. )
  261. def make_model_input_from_broadcasted_tensor_dict(
  262. self,
  263. tensor_dict: Dict[str, Any],
  264. ) -> CPUModelInput:
  265. return CPUModelInput.from_broadcasted_tensor_dict(
  266. tensor_dict,
  267. attn_backend=self.attn_backend,
  268. )
  269. def prepare_model_input(
  270. self,
  271. seq_group_metadata_list: List[SequenceGroupMetadata],
  272. virtual_engine: int = 0,
  273. finished_requests_ids: Optional[List[str]] = None
  274. ) -> CPUModelInput:
  275. multi_modal_kwargs = None
  276. # NOTE: We assume that all sequences in the group are all prompts or
  277. # all decodes.
  278. is_prompt = seq_group_metadata_list[0].is_prompt
  279. # Prepare input tensors.
  280. if is_prompt:
  281. (input_tokens, input_positions, attn_metadata, seq_lens,
  282. multi_modal_kwargs
  283. ) = self._prepare_prompt(seq_group_metadata_list)
  284. else:
  285. (input_tokens, input_positions,
  286. attn_metadata) = self._prepare_decode(seq_group_metadata_list)
  287. seq_lens = []
  288. sampling_metadata = SamplingMetadata.prepare(
  289. seq_group_metadata_list,
  290. seq_lens,
  291. # query_lens is not needed if chunked prefill is not
  292. # supported. Since CPU worker doesn't support chunked prefill
  293. # just use seq_lens instead.
  294. seq_lens,
  295. self.device,
  296. pin_memory=False)
  297. return CPUModelInput(
  298. input_tokens=input_tokens,
  299. input_positions=input_positions,
  300. attn_metadata=attn_metadata,
  301. sampling_metadata=sampling_metadata,
  302. multi_modal_kwargs=multi_modal_kwargs,
  303. )
  304. @torch.inference_mode()
  305. def execute_model(
  306. self,
  307. model_input: CPUModelInput,
  308. kv_caches: List[torch.Tensor],
  309. intermediate_tensors: Optional[IntermediateTensors] = None,
  310. num_steps: int = 1,
  311. ) -> Optional[List[SamplerOutput]]:
  312. if num_steps > 1:
  313. raise ValueError(
  314. "CPU worker does not support multi-step execution.")
  315. model_executable = self.model
  316. execute_model_kwargs = {
  317. "input_ids": model_input.input_tokens,
  318. "positions": model_input.input_positions,
  319. "kv_caches": kv_caches,
  320. "attn_metadata": model_input.attn_metadata,
  321. **(model_input.multi_modal_kwargs or {}),
  322. }
  323. hidden_states = model_executable(**execute_model_kwargs)
  324. # Compute the logits.
  325. logits = self.model.compute_logits(hidden_states,
  326. model_input.sampling_metadata)
  327. # Only perform sampling in the driver worker.
  328. if not self.is_driver_worker:
  329. return []
  330. # Sample the next token.
  331. output = self.model.sample(
  332. logits=logits,
  333. sampling_metadata=model_input.sampling_metadata,
  334. )
  335. return [output]