cpu_model_runner.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. from dataclasses import dataclass
  2. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
  3. import torch
  4. from torch import nn
  5. from aphrodite.attention import AttentionMetadata, get_attn_backend
  6. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  7. LoRAConfig, ModelConfig, ParallelConfig,
  8. PromptAdapterConfig, SchedulerConfig)
  9. from aphrodite.common.sequence import (IntermediateTensors,
  10. SequenceGroupMetadata)
  11. from aphrodite.common.utils import make_tensor_with_pad
  12. from aphrodite.modeling.layers.sampler import SamplerOutput
  13. from aphrodite.modeling.model_loader import get_model
  14. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  15. from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
  16. MultiModalInputs)
  17. from aphrodite.task_handler.model_runner_base import (
  18. ModelRunnerBase, ModelRunnerInputBase,
  19. _add_attn_metadata_broadcastable_dict,
  20. _add_sampling_metadata_broadcastable_dict,
  21. _init_attn_metadata_from_tensor_dict,
  22. _init_sampling_metadata_from_tensor_dict)
  23. if TYPE_CHECKING:
  24. from aphrodite.attention.backends.abstract import AttentionBackend
  25. _PAD_SLOT_ID = -1
  @dataclass(frozen=True)
  class CPUModelInput(ModelRunnerInputBase):
      """Immutable bundle of model inputs built by the CPUModelRunner.

      The bundle can be flattened into a broadcastable dict and rebuilt from
      one, presumably so the driver can share it with the other workers.
      """
      # Token ids and their positions.
      input_tokens: Optional[torch.Tensor] = None
      input_positions: Optional[torch.Tensor] = None
      attn_metadata: Optional["AttentionMetadata"] = None
      sampling_metadata: Optional["SamplingMetadata"] = None
      multi_modal_kwargs: Optional[BatchedTensorInputs] = None
      # Not included in the dict produced by as_broadcastable_tensor_dict.
      virtual_engine: Optional[int] = None

      def as_broadcastable_tensor_dict(
              self) -> Dict[str, Union[int, torch.Tensor]]:
          """Flatten this input into a dict suitable for broadcasting.

          Token ids, positions and multi-modal kwargs are stored directly;
          attention and sampling metadata are merged in by the shared
          ``_add_*_broadcastable_dict`` helpers.
          """
          broadcast_dict: Dict[str, Any] = dict(
              input_tokens=self.input_tokens,
              input_positions=self.input_positions,
              multi_modal_kwargs=self.multi_modal_kwargs,
          )
          _add_attn_metadata_broadcastable_dict(broadcast_dict,
                                                self.attn_metadata)
          _add_sampling_metadata_broadcastable_dict(broadcast_dict,
                                                    self.sampling_metadata)
          return broadcast_dict

      @classmethod
      def from_broadcasted_tensor_dict(
          cls: Type["CPUModelInput"],
          tensor_dict: Dict[str, Any],
          attn_backend: Optional["AttentionBackend"] = None,
      ) -> "CPUModelInput":
          """Rebuild a CPUModelInput from a dict made by
          ``as_broadcastable_tensor_dict``.

          Sampling metadata is always reconstructed. Attention metadata is
          reconstructed only when ``attn_backend`` is supplied.
          """
          restored = _init_sampling_metadata_from_tensor_dict(tensor_dict)
          if attn_backend is None:
              return cls(**restored)
          restored = _init_attn_metadata_from_tensor_dict(attn_backend,
                                                          restored)
          return cls(**restored)
  class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
      """Prepares batched inputs for, and runs, a model on the CPU.

      A batch given to ``prepare_model_input`` is assumed to be either all
      prompts (prefill) or all decodes. Chunked prefill and multi-step
      execution are rejected.
      """

      def __init__(
          self,
          model_config: ModelConfig,
          parallel_config: ParallelConfig,
          scheduler_config: SchedulerConfig,
          device_config: DeviceConfig,
          cache_config: CacheConfig,
          load_config: LoadConfig,
          lora_config: Optional[LoRAConfig],
          kv_cache_dtype: Optional[str] = "auto",
          prompt_adapter_config: Optional[PromptAdapterConfig] = None,
          is_driver_worker: bool = False,
          *args,
          **kwargs,
      ):
          """Store the configs, resolve the attention backend and set up the
          multi-modal input mapper. The model itself is created later by
          ``load_model``.

          Extra ``*args`` / ``**kwargs`` are accepted and ignored.

          Raises:
              ValueError: if ``scheduler_config`` enables chunked prefill.
          """
          self.model_config = model_config
          self.parallel_config = parallel_config
          self.scheduler_config = scheduler_config
          # Currently, CPU worker doesn't support chunked prefill. Raise
          # instead of asserting so the check survives `python -O`.
          if self.scheduler_config.chunked_prefill_enabled:
              raise ValueError("CPU worker does not support chunked prefill.")
          self.device_config = device_config
          self.cache_config = cache_config
          self.lora_config = lora_config
          self.prompt_adapter_config = prompt_adapter_config
          self.load_config = load_config
          self.is_driver_worker = is_driver_worker

          self.device = self.device_config.device

          self.kv_cache_dtype = kv_cache_dtype
          self.sliding_window = model_config.get_sliding_window()
          self.block_size = cache_config.block_size
          self.attn_backend = get_attn_backend(
              self.model_config.get_head_size(),
              self.sliding_window,
              self.model_config.dtype,
              self.kv_cache_dtype,
              self.block_size,
              self.model_config.is_attention_free(),
          )

          # Multi-modal data support
          self.mm_registry = MULTIMODAL_REGISTRY
          self.multi_modal_input_mapper = self.mm_registry \
              .create_input_mapper(self.model_config)
          self.mm_registry.init_mm_limits_per_prompt(self.model_config)

          # Lazy initialization.
          self.model: nn.Module  # Set by load_model()

      def load_model(self) -> None:
          """Build the model from the stored configs into ``self.model``."""
          self.model = get_model(model_config=self.model_config,
                                 load_config=self.load_config,
                                 device_config=self.device_config,
                                 lora_config=self.lora_config,
                                 parallel_config=self.parallel_config,
                                 scheduler_config=self.scheduler_config,
                                 cache_config=self.cache_config)

      def _prepare_prompt(
          self,
          seq_group_metadata_list: List[SequenceGroupMetadata],
      ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
                 BatchedTensorInputs]:
          """Flatten a non-empty batch of prompt (prefill) sequence groups.

          Every group must be a prompt holding exactly one sequence.

          Returns:
              ``(input_tokens, input_positions, attn_metadata, seq_lens,
              multi_modal_kwargs)``. The first two are 1-D ``torch.long``
              tensors with all prompts concatenated. ``seq_lens`` holds each
              prompt's token count. ``multi_modal_kwargs`` batches the mapped
              multi-modal data of the groups that carry any.
          """
          assert len(seq_group_metadata_list) > 0
          input_tokens: List[int] = []
          input_positions: List[int] = []
          slot_mapping: List[int] = []
          seq_lens: List[int] = []
          multi_modal_inputs_list: List[MultiModalInputs] = []

          for seq_group_metadata in seq_group_metadata_list:
              assert seq_group_metadata.is_prompt
              seq_ids = list(seq_group_metadata.seq_data.keys())
              assert len(seq_ids) == 1
              seq_id = seq_ids[0]

              seq_data = seq_group_metadata.seq_data[seq_id]
              prompt_tokens = seq_data.get_token_ids()
              computed_len = seq_data.get_num_computed_tokens()
              seq_len = len(prompt_tokens)

              seq_lens.append(seq_len)  # Prompt token num
              input_tokens.extend(prompt_tokens)  # Token ids

              # Token position ids
              # NOTE: Here we assume that the first token in the prompt
              # is always the first token in the sequence.
              # NOTE(review): tokens cover the whole prompt while positions
              # and slots start at computed_len, so the three lists only
              # line up when computed_len == 0. That presumably always holds
              # with chunked prefill disabled - TODO confirm.
              input_positions.extend(range(computed_len, seq_len))

              mm_data = seq_group_metadata.multi_modal_data
              if mm_data:
                  mm_kwargs = self.multi_modal_input_mapper(mm_data)
                  multi_modal_inputs_list.append(mm_kwargs)

              # Compute the slot mapping.
              block_table = seq_group_metadata.block_tables[seq_id]
              # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
              # where start_idx is max(0, seq_len - sliding_window).
              # For example, if the prompt len is 10, sliding window is 8, and
              # block size is 4, the first two tokens are masked and the slot
              # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
              start_idx = 0
              if self.sliding_window is not None:
                  start_idx = max(0, seq_len - self.sliding_window)
              for i in range(computed_len, seq_len):
                  if i < start_idx:
                      slot_mapping.append(_PAD_SLOT_ID)
                      continue
                  block_number = block_table[i //
                                             self.block_size]  # type: ignore
                  block_offset = i % self.block_size  # type: ignore
                  slot = block_number * self.block_size + block_offset
                  slot_mapping.append(slot)

          num_prompt_tokens = len(input_tokens)

          input_tokens = torch.tensor(input_tokens,
                                      dtype=torch.long,
                                      device=self.device)  # type: ignore
          input_positions = torch.tensor(input_positions,
                                         dtype=torch.long,
                                         device=self.device)  # type: ignore
          slot_mapping = torch.tensor(slot_mapping,
                                      dtype=torch.long,
                                      device=self.device)  # type: ignore

          attn_metadata = self.attn_backend.make_metadata(
              is_prompt=True,
              seq_lens=seq_lens,
              seq_lens_tensor=torch.tensor([]),
              max_decode_seq_len=0,
              num_prefills=len(seq_lens),
              num_prefill_tokens=num_prompt_tokens,
              num_decode_tokens=0,
              block_tables=torch.tensor([]),
              slot_mapping=slot_mapping,
          )

          multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

          return (input_tokens, input_positions, attn_metadata, seq_lens,
                  multi_modal_kwargs)

      def _prepare_decode(
          self,
          seq_group_metadata_list: List[SequenceGroupMetadata],
      ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
          """Build one-token-per-sequence inputs for a batch of decode groups.

          For every sequence, the last token id is fed at position
          ``get_len() - 1``. With a sliding window, the sequence length given
          to the attention backend is clamped to the window. The block table
          is also cut down to its last ``sliding_window // block_size``
          entries. Block tables are padded with 0 into a 2-D ``torch.int``
          tensor.

          Returns:
              ``(input_tokens, input_positions, attn_metadata)``.
          """
          assert len(seq_group_metadata_list) > 0
          input_tokens: List[int] = []
          input_positions: List[int] = []
          slot_mapping: List[int] = []
          seq_lens: List[int] = []
          block_tables: List[List[int]] = []

          # Loop-invariant: number of trailing blocks to keep per block table.
          sliding_window_blocks: Optional[int] = None
          if self.sliding_window is not None:
              sliding_window_blocks = self.sliding_window // self.block_size

          for seq_group_metadata in seq_group_metadata_list:
              assert not seq_group_metadata.is_prompt
              assert seq_group_metadata.token_chunk_size == 1

              for seq_id, seq_data in seq_group_metadata.seq_data.items():
                  generation_token = seq_data.get_last_token_id()
                  input_tokens.append(generation_token)

                  seq_len = seq_data.get_len()
                  position = seq_len - 1
                  input_positions.append(position)

                  if self.sliding_window is not None:
                      seq_len = min(seq_len, self.sliding_window)
                  seq_lens.append(seq_len)

                  block_table = seq_group_metadata.block_tables[seq_id]
                  block_number = block_table[position // self.block_size]
                  block_offset = position % self.block_size
                  slot = block_number * self.block_size + block_offset
                  slot_mapping.append(slot)

                  if sliding_window_blocks is not None:
                      block_table = block_table[-sliding_window_blocks:]
                  block_tables.append(block_table)

          max_decode_seq_len = max(seq_lens)

          input_tokens = torch.tensor(input_tokens,
                                      dtype=torch.long,
                                      device=self.device)
          input_positions = torch.tensor(input_positions,
                                         dtype=torch.long,
                                         device=self.device)
          slot_mapping = torch.tensor(slot_mapping,
                                      dtype=torch.long,
                                      device=self.device)
          seq_lens_tensor = torch.tensor(seq_lens,
                                         dtype=torch.int,
                                         device=self.device)

          block_tables = make_tensor_with_pad(
              block_tables,
              pad=0,
              dtype=torch.int,
              device=self.device,
          )

          attn_metadata = self.attn_backend.make_metadata(
              is_prompt=False,
              slot_mapping=slot_mapping,
              seq_lens=seq_lens,
              seq_lens_tensor=seq_lens_tensor,
              max_decode_seq_len=max_decode_seq_len,
              num_prefill_tokens=0,
              num_decode_tokens=len(input_tokens),
              num_prefills=0,
              block_tables=block_tables,
          )
          return (
              input_tokens,
              input_positions,
              attn_metadata,
          )

      def make_model_input_from_broadcasted_tensor_dict(
          self,
          tensor_dict: Dict[str, Any],
      ) -> CPUModelInput:
          """Rebuild a CPUModelInput from a broadcast dict, using this
          runner's attention backend to restore the attention metadata."""
          return CPUModelInput.from_broadcasted_tensor_dict(
              tensor_dict,
              attn_backend=self.attn_backend,
          )

      def prepare_model_input(
          self,
          seq_group_metadata_list: List[SequenceGroupMetadata],
          virtual_engine: int = 0,
          finished_requests_ids: Optional[List[str]] = None
      ) -> CPUModelInput:
          """Turn a batch of scheduled sequence groups into a CPUModelInput.

          The batch kind (prefill or decode) is taken from the first group.
          ``virtual_engine`` is accepted for interface compatibility but is
          not used here. ``finished_requests_ids`` is forwarded to
          ``get_generators``.
          """
          multi_modal_kwargs = None
          # NOTE: We assume that all sequences in the group are all prompts or
          # all decodes.
          is_prompt = seq_group_metadata_list[0].is_prompt
          # Prepare input tensors.
          if is_prompt:
              (input_tokens, input_positions, attn_metadata, seq_lens,
               multi_modal_kwargs
               ) = self._prepare_prompt(seq_group_metadata_list)
          else:
              (input_tokens, input_positions,
               attn_metadata) = self._prepare_decode(seq_group_metadata_list)
              seq_lens = []
          sampling_metadata = SamplingMetadata.prepare(
              seq_group_metadata_list,
              seq_lens,
              # query_lens is not needed if chunked prefill is not
              # supported. Since CPU worker doesn't support chunked prefill
              # just use seq_lens instead.
              seq_lens,
              self.device,
              pin_memory=False,
              generators=self.get_generators(finished_requests_ids))
          return CPUModelInput(
              input_tokens=input_tokens,
              input_positions=input_positions,
              attn_metadata=attn_metadata,
              sampling_metadata=sampling_metadata,
              multi_modal_kwargs=multi_modal_kwargs,
          )

      @torch.no_grad()
      def execute_model(
          self,
          model_input: CPUModelInput,
          kv_caches: List[torch.Tensor],
          intermediate_tensors: Optional[IntermediateTensors] = None,
          num_steps: int = 1,
      ) -> Optional[List[SamplerOutput]]:
          """Run the forward pass, compute logits and, on the driver, sample.

          Args:
              model_input: inputs built by ``prepare_model_input`` or rebuilt
                  from a broadcast dict.
              kv_caches: KV cache tensors passed unchanged to the model.
              intermediate_tensors: accepted but unused by this runner.
              num_steps: must be 1.

          Returns:
              ``[SamplerOutput]`` on the driver worker; ``[]`` on other
              workers, which still compute logits but skip sampling.

          Raises:
              ValueError: if ``num_steps > 1``.
          """
          if num_steps > 1:
              raise ValueError(
                  "CPU worker does not support multi-step execution.")

          execute_model_kwargs = {
              "input_ids":
              model_input.input_tokens,
              "positions":
              model_input.input_positions,
              "kv_caches":
              kv_caches,
              "attn_metadata":
              model_input.attn_metadata,
              **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
                                           device=self.device),
          }

          hidden_states = self.model(**execute_model_kwargs)

          # Compute the logits.
          logits = self.model.compute_logits(hidden_states,
                                             model_input.sampling_metadata)

          # Only perform sampling in the driver worker.
          if not self.is_driver_worker:
              return []

          # Sample the next token.
          output = self.model.sample(
              logits=logits,
              sampling_metadata=model_input.sampling_metadata,
          )
          return [output]