# cpu_model_runner.py (15 KB)
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union

import torch
from torch import nn

from aphrodite.attention import AttentionMetadata, get_attn_backend
from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig, ParallelConfig,
                                     PromptAdapterConfig, SchedulerConfig)
from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                       SequenceGroupMetadata)
from aphrodite.common.utils import make_tensor_with_pad
from aphrodite.modeling import SamplingMetadata
from aphrodite.modeling.model_loader import get_model
from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                                  MultiModalInputs)
from aphrodite.task_handler.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase,
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from aphrodite.attention.backends.abstract import AttentionBackend
  24. _PAD_SLOT_ID = -1
  25. @dataclass(frozen=True)
  26. class CPUModelInput(ModelRunnerInputBase):
  27. """
  28. Used by the CPUModelRunner.
  29. """
  30. input_tokens: Optional[torch.Tensor] = None
  31. input_positions: Optional[torch.Tensor] = None
  32. attn_metadata: Optional["AttentionMetadata"] = None
  33. sampling_metadata: Optional["SamplingMetadata"] = None
  34. multi_modal_kwargs: Optional[BatchedTensorInputs] = None
  35. virtual_engine: Optional[int] = None
  36. def as_broadcastable_tensor_dict(
  37. self) -> Dict[str, Union[int, torch.Tensor]]:
  38. tensor_dict = {
  39. "input_tokens": self.input_tokens,
  40. "input_positions": self.input_positions,
  41. "multi_modal_kwargs": self.multi_modal_kwargs,
  42. }
  43. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  44. _add_sampling_metadata_broadcastable_dict(tensor_dict,
  45. self.sampling_metadata)
  46. return tensor_dict
  47. @classmethod
  48. def from_broadcasted_tensor_dict(
  49. cls: Type["CPUModelInput"],
  50. tensor_dict: Dict[str, Any],
  51. attn_backend: Optional["AttentionBackend"] = None
  52. ) -> "CPUModelInput":
  53. tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
  54. if attn_backend is not None:
  55. tensor_dict = _init_attn_metadata_from_tensor_dict(
  56. attn_backend, tensor_dict)
  57. return cls(**tensor_dict)
  58. class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
  59. def __init__(
  60. self,
  61. model_config: ModelConfig,
  62. parallel_config: ParallelConfig,
  63. scheduler_config: SchedulerConfig,
  64. device_config: DeviceConfig,
  65. cache_config: CacheConfig,
  66. load_config: LoadConfig,
  67. lora_config: Optional[LoRAConfig],
  68. kv_cache_dtype: Optional[str] = "auto",
  69. prompt_adapter_config: Optional[PromptAdapterConfig] = None,
  70. is_driver_worker: bool = False,
  71. *args,
  72. **kwargs,
  73. ):
  74. self.model_config = model_config
  75. self.parallel_config = parallel_config
  76. self.scheduler_config = scheduler_config
  77. # Currently, CPU worker doesn't support chunked prefill.
  78. assert self.scheduler_config.chunked_prefill_enabled is False
  79. self.device_config = device_config
  80. self.cache_config = cache_config
  81. self.lora_config = lora_config
  82. self.prompt_adapter_config = prompt_adapter_config
  83. self.load_config = load_config
  84. self.is_driver_worker = is_driver_worker
  85. self.device = self.device_config.device
  86. self.kv_cache_dtype = kv_cache_dtype
  87. self.sliding_window = model_config.get_sliding_window()
  88. self.block_size = cache_config.block_size
  89. self.attn_backend = get_attn_backend(
  90. self.model_config.get_head_size(),
  91. self.model_config.get_sliding_window(),
  92. self.model_config.dtype,
  93. self.kv_cache_dtype,
  94. self.block_size,
  95. self.model_config.is_attention_free(),
  96. )
  97. # Multi-modal data support
  98. self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
  99. .create_input_mapper(self.model_config)
  100. # Lazy initialization.
  101. self.model: nn.Module # Set after init_Model
  102. def load_model(self) -> None:
  103. self.model = get_model(model_config=self.model_config,
  104. load_config=self.load_config,
  105. device_config=self.device_config,
  106. lora_config=self.lora_config,
  107. parallel_config=self.parallel_config,
  108. scheduler_config=self.scheduler_config,
  109. cache_config=self.cache_config)
  110. def _prepare_prompt(
  111. self,
  112. seq_group_metadata_list: List[SequenceGroupMetadata],
  113. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
  114. BatchedTensorInputs]:
  115. assert len(seq_group_metadata_list) > 0
  116. input_tokens: List[int] = []
  117. input_positions: List[int] = []
  118. slot_mapping: List[int] = []
  119. seq_lens: List[int] = []
  120. multi_modal_inputs_list: List[MultiModalInputs] = []
  121. for seq_group_metadata in seq_group_metadata_list:
  122. assert seq_group_metadata.is_prompt
  123. seq_ids = list(seq_group_metadata.seq_data.keys())
  124. assert len(seq_ids) == 1
  125. seq_id = seq_ids[0]
  126. seq_data = seq_group_metadata.seq_data[seq_id]
  127. prompt_tokens = seq_data.get_token_ids()
  128. computed_len = seq_data.get_num_computed_tokens()
  129. seq_len = len(prompt_tokens)
  130. seq_lens.append(seq_len) # Prompt token num
  131. input_tokens.extend(prompt_tokens) # Token ids
  132. # Token position ids
  133. # NOTE: Here we assume that the first token in the prompt
  134. # is always the first token in the sequence.
  135. input_positions.extend(list(range(computed_len, seq_len)))
  136. mm_data = seq_group_metadata.multi_modal_data
  137. if mm_data:
  138. mm_kwargs = self.multi_modal_input_mapper(mm_data)
  139. multi_modal_inputs_list.append(mm_kwargs)
  140. # Compute the slot mapping.
  141. block_table = seq_group_metadata.block_tables[seq_id]
  142. # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
  143. # where start_idx is max(0, seq_len - sliding_window).
  144. # For example, if the prompt len is 10, sliding window is 8, and
  145. # block size is 4, the first two tokens are masked and the slot
  146. # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
  147. start_idx = 0
  148. if self.sliding_window is not None:
  149. start_idx = max(0, seq_len - self.sliding_window)
  150. for i in range(computed_len, seq_len):
  151. if i < start_idx:
  152. slot_mapping.append(_PAD_SLOT_ID)
  153. continue
  154. block_number = block_table[i //
  155. self.block_size] # type: ignore
  156. block_offset = i % self.block_size # type: ignore
  157. slot = block_number * self.block_size + block_offset
  158. slot_mapping.append(slot)
  159. num_prompt_tokens = len(input_tokens)
  160. input_tokens = torch.tensor(input_tokens,
  161. dtype=torch.long,
  162. device=self.device) # type: ignore
  163. input_positions = torch.tensor(input_positions,
  164. dtype=torch.long,
  165. device=self.device) # type: ignore
  166. slot_mapping = torch.tensor(slot_mapping,
  167. dtype=torch.long,
  168. device=self.device) # type: ignore
  169. attn_metadata = self.attn_backend.make_metadata(
  170. is_prompt=True,
  171. seq_lens=seq_lens,
  172. seq_lens_tensor=torch.tensor([]),
  173. max_decode_seq_len=0,
  174. num_prefills=len(seq_lens),
  175. num_prefill_tokens=num_prompt_tokens,
  176. num_decode_tokens=0,
  177. block_tables=torch.tensor([]),
  178. slot_mapping=slot_mapping,
  179. )
  180. multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
  181. return (input_tokens, input_positions, attn_metadata, seq_lens,
  182. multi_modal_kwargs)
  183. def _prepare_decode(
  184. self,
  185. seq_group_metadata_list: List[SequenceGroupMetadata],
  186. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
  187. assert len(seq_group_metadata_list) > 0
  188. input_tokens: List[int] = []
  189. input_positions: List[int] = []
  190. slot_mapping: List[int] = []
  191. seq_lens: List[int] = []
  192. block_tables: List[List[int]] = []
  193. for seq_group_metadata in seq_group_metadata_list:
  194. assert not seq_group_metadata.is_prompt
  195. assert seq_group_metadata.token_chunk_size == 1
  196. seq_ids = list(seq_group_metadata.seq_data.keys())
  197. for seq_id in seq_ids:
  198. seq_data = seq_group_metadata.seq_data[seq_id]
  199. generation_token = seq_data.get_last_token_id()
  200. input_tokens.append(generation_token)
  201. seq_len = seq_data.get_len()
  202. position = seq_len - 1
  203. input_positions.append(position)
  204. seq_len = seq_len if self.sliding_window is None else min(
  205. seq_len, self.sliding_window)
  206. seq_lens.append(seq_len)
  207. block_table = seq_group_metadata.block_tables[seq_id]
  208. block_number = block_table[position // self.block_size]
  209. block_offset = position % self.block_size
  210. slot = block_number * self.block_size + block_offset
  211. slot_mapping.append(slot)
  212. if self.sliding_window is not None:
  213. sliding_window_blocks = (self.sliding_window //
  214. self.block_size)
  215. block_table = block_table[-sliding_window_blocks:]
  216. block_tables.append(block_table)
  217. max_decode_seq_len = max(seq_lens)
  218. input_tokens = torch.tensor(input_tokens,
  219. dtype=torch.long,
  220. device=self.device)
  221. input_positions = torch.tensor(input_positions,
  222. dtype=torch.long,
  223. device=self.device)
  224. slot_mapping = torch.tensor(slot_mapping,
  225. dtype=torch.long,
  226. device=self.device)
  227. seq_lens_tensor = torch.tensor(seq_lens,
  228. dtype=torch.int,
  229. device=self.device)
  230. block_tables = make_tensor_with_pad(
  231. block_tables,
  232. pad=0,
  233. dtype=torch.int,
  234. device=self.device,
  235. )
  236. attn_metadata = self.attn_backend.make_metadata(
  237. is_prompt=False,
  238. slot_mapping=slot_mapping,
  239. seq_lens=seq_lens,
  240. seq_lens_tensor=seq_lens_tensor,
  241. max_decode_seq_len=max_decode_seq_len,
  242. num_prefill_tokens=0,
  243. num_decode_tokens=len(input_tokens),
  244. num_prefills=0,
  245. block_tables=block_tables,
  246. )
  247. return (
  248. input_tokens,
  249. input_positions,
  250. attn_metadata,
  251. )
  252. def make_model_input_from_broadcasted_tensor_dict(
  253. self,
  254. tensor_dict: Dict[str, Any],
  255. ) -> CPUModelInput:
  256. return CPUModelInput.from_broadcasted_tensor_dict(
  257. tensor_dict,
  258. attn_backend=self.attn_backend,
  259. )
  260. def prepare_model_input(
  261. self,
  262. seq_group_metadata_list: List[SequenceGroupMetadata],
  263. virtual_engine: int = 0,
  264. finished_requests_ids: Optional[List[str]] = None
  265. ) -> CPUModelInput:
  266. multi_modal_kwargs = None
  267. # NOTE: We assume that all sequences in the group are all prompts or
  268. # all decodes.
  269. is_prompt = seq_group_metadata_list[0].is_prompt
  270. # Prepare input tensors.
  271. if is_prompt:
  272. (input_tokens, input_positions, attn_metadata, seq_lens,
  273. multi_modal_kwargs
  274. ) = self._prepare_prompt(seq_group_metadata_list)
  275. else:
  276. (input_tokens, input_positions,
  277. attn_metadata) = self._prepare_decode(seq_group_metadata_list)
  278. seq_lens = []
  279. sampling_metadata = SamplingMetadata.prepare(
  280. seq_group_metadata_list,
  281. seq_lens,
  282. # query_lens is not needed if chunked prefill is not
  283. # supported. Since CPU worker doesn't support chunked prefill
  284. # just use seq_lens instead.
  285. seq_lens,
  286. self.device,
  287. pin_memory=False,
  288. generators=self.get_generators(finished_requests_ids))
  289. return CPUModelInput(
  290. input_tokens=input_tokens,
  291. input_positions=input_positions,
  292. attn_metadata=attn_metadata,
  293. sampling_metadata=sampling_metadata,
  294. multi_modal_kwargs=multi_modal_kwargs,
  295. )
  296. @torch.no_grad()
  297. def execute_model(
  298. self,
  299. model_input: CPUModelInput,
  300. kv_caches: List[torch.Tensor],
  301. intermediate_tensors: Optional[IntermediateTensors] = None,
  302. num_steps: int = 1,
  303. ) -> Optional[List[SamplerOutput]]:
  304. if num_steps > 1:
  305. raise ValueError(
  306. "CPU worker does not support multi-step execution.")
  307. model_executable = self.model
  308. execute_model_kwargs = {
  309. "input_ids":
  310. model_input.input_tokens,
  311. "positions":
  312. model_input.input_positions,
  313. "kv_caches":
  314. kv_caches,
  315. "attn_metadata":
  316. model_input.attn_metadata,
  317. **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
  318. device=self.device),
  319. }
  320. hidden_states = model_executable(**execute_model_kwargs)
  321. # Compute the logits.
  322. logits = self.model.compute_logits(hidden_states,
  323. model_input.sampling_metadata)
  324. # Only perform sampling in the driver worker.
  325. if not self.is_driver_worker:
  326. return []
  327. # Sample the next token.
  328. output = self.model.sample(
  329. logits=logits,
  330. sampling_metadata=model_input.sampling_metadata,
  331. )
  332. return [output]