cpu_model_runner.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. from dataclasses import dataclass
  2. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
  3. import torch
  4. from torch import nn
  5. from aphrodite.attention import AttentionMetadata, get_attn_backend
  6. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  7. LoRAConfig, ModelConfig, ParallelConfig,
  8. PromptAdapterConfig, SchedulerConfig)
  9. from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
  10. SequenceGroupMetadata)
  11. from aphrodite.common.utils import make_tensor_with_pad
  12. from aphrodite.modeling import SamplingMetadata
  13. from aphrodite.modeling.model_loader import get_model
  14. from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
  15. MultiModalInputs)
  16. from aphrodite.task_handler.model_runner_base import (
  17. ModelRunnerBase, ModelRunnerInputBase,
  18. _add_attn_metadata_broadcastable_dict,
  19. _add_sampling_metadata_broadcastable_dict,
  20. _init_attn_metadata_from_tensor_dict,
  21. _init_sampling_metadata_from_tensor_dict)
  22. if TYPE_CHECKING:
  23. from aphrodite.attention.backends.abstract import AttentionBackend
# Sentinel value written into ``slot_mapping`` by
# CPUModelRunner._prepare_prompt for prompt positions that fall before the
# sliding window, i.e. tokens that are not assigned a real KV-cache slot.
# NOTE(review): presumably the attention backend skips cache writes for this
# value — confirm against the CPU attention implementation.
_PAD_SLOT_ID: int = -1
  25. @dataclass(frozen=True)
  26. class CPUModelInput(ModelRunnerInputBase):
  27. """
  28. Used by the CPUModelRunner.
  29. """
  30. input_tokens: Optional[torch.Tensor] = None
  31. input_positions: Optional[torch.Tensor] = None
  32. attn_metadata: Optional["AttentionMetadata"] = None
  33. sampling_metadata: Optional["SamplingMetadata"] = None
  34. multi_modal_kwargs: Optional[BatchedTensorInputs] = None
  35. virtual_engine: Optional[int] = None
  36. def as_broadcastable_tensor_dict(
  37. self) -> Dict[str, Union[int, torch.Tensor]]:
  38. tensor_dict = {
  39. "input_tokens": self.input_tokens,
  40. "input_positions": self.input_positions,
  41. "multi_modal_kwargs": self.multi_modal_kwargs,
  42. }
  43. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  44. _add_sampling_metadata_broadcastable_dict(tensor_dict,
  45. self.sampling_metadata)
  46. return tensor_dict
  47. @classmethod
  48. def from_broadcasted_tensor_dict(
  49. cls: Type["CPUModelInput"],
  50. tensor_dict: Dict[str, Any],
  51. attn_backend: Optional["AttentionBackend"] = None
  52. ) -> "CPUModelInput":
  53. tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
  54. if attn_backend is not None:
  55. tensor_dict = _init_attn_metadata_from_tensor_dict(
  56. attn_backend, tensor_dict)
  57. return cls(**tensor_dict)
  58. class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
  59. def __init__(
  60. self,
  61. model_config: ModelConfig,
  62. parallel_config: ParallelConfig,
  63. scheduler_config: SchedulerConfig,
  64. device_config: DeviceConfig,
  65. cache_config: CacheConfig,
  66. load_config: LoadConfig,
  67. lora_config: Optional[LoRAConfig],
  68. kv_cache_dtype: Optional[str] = "auto",
  69. prompt_adapter_config: Optional[PromptAdapterConfig] = None,
  70. is_driver_worker: bool = False,
  71. *args,
  72. **kwargs,
  73. ):
  74. self.model_config = model_config
  75. self.parallel_config = parallel_config
  76. self.scheduler_config = scheduler_config
  77. # Currently, CPU worker doesn't support chunked prefill.
  78. assert self.scheduler_config.chunked_prefill_enabled is False
  79. self.device_config = device_config
  80. self.cache_config = cache_config
  81. self.lora_config = lora_config
  82. self.prompt_adapter_config = prompt_adapter_config
  83. self.load_config = load_config
  84. self.is_driver_worker = is_driver_worker
  85. self.device = self.device_config.device
  86. self.kv_cache_dtype = kv_cache_dtype
  87. self.sliding_window = model_config.get_sliding_window()
  88. self.block_size = cache_config.block_size
  89. self.attn_backend = get_attn_backend(
  90. self.model_config.get_head_size(),
  91. self.model_config.get_sliding_window(),
  92. self.model_config.dtype,
  93. self.kv_cache_dtype,
  94. self.block_size,
  95. self.model_config.is_attention_free(),
  96. )
  97. # Multi-modal data support
  98. self.mm_registry = MULTIMODAL_REGISTRY
  99. self.multi_modal_input_mapper = self.mm_registry \
  100. .create_input_mapper(self.model_config)
  101. self.mm_registry.init_mm_limits_per_prompt(self.model_config)
  102. # Lazy initialization.
  103. self.model: nn.Module # Set after init_Model
  104. def load_model(self) -> None:
  105. self.model = get_model(model_config=self.model_config,
  106. load_config=self.load_config,
  107. device_config=self.device_config,
  108. lora_config=self.lora_config,
  109. parallel_config=self.parallel_config,
  110. scheduler_config=self.scheduler_config,
  111. cache_config=self.cache_config)
  112. def _prepare_prompt(
  113. self,
  114. seq_group_metadata_list: List[SequenceGroupMetadata],
  115. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
  116. BatchedTensorInputs]:
  117. assert len(seq_group_metadata_list) > 0
  118. input_tokens: List[int] = []
  119. input_positions: List[int] = []
  120. slot_mapping: List[int] = []
  121. seq_lens: List[int] = []
  122. multi_modal_inputs_list: List[MultiModalInputs] = []
  123. for seq_group_metadata in seq_group_metadata_list:
  124. assert seq_group_metadata.is_prompt
  125. seq_ids = list(seq_group_metadata.seq_data.keys())
  126. assert len(seq_ids) == 1
  127. seq_id = seq_ids[0]
  128. seq_data = seq_group_metadata.seq_data[seq_id]
  129. prompt_tokens = seq_data.get_token_ids()
  130. computed_len = seq_data.get_num_computed_tokens()
  131. seq_len = len(prompt_tokens)
  132. seq_lens.append(seq_len) # Prompt token num
  133. input_tokens.extend(prompt_tokens) # Token ids
  134. # Token position ids
  135. # NOTE: Here we assume that the first token in the prompt
  136. # is always the first token in the sequence.
  137. input_positions.extend(list(range(computed_len, seq_len)))
  138. mm_data = seq_group_metadata.multi_modal_data
  139. if mm_data:
  140. mm_kwargs = self.multi_modal_input_mapper(mm_data)
  141. multi_modal_inputs_list.append(mm_kwargs)
  142. # Compute the slot mapping.
  143. block_table = seq_group_metadata.block_tables[seq_id]
  144. # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
  145. # where start_idx is max(0, seq_len - sliding_window).
  146. # For example, if the prompt len is 10, sliding window is 8, and
  147. # block size is 4, the first two tokens are masked and the slot
  148. # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
  149. start_idx = 0
  150. if self.sliding_window is not None:
  151. start_idx = max(0, seq_len - self.sliding_window)
  152. for i in range(computed_len, seq_len):
  153. if i < start_idx:
  154. slot_mapping.append(_PAD_SLOT_ID)
  155. continue
  156. block_number = block_table[i //
  157. self.block_size] # type: ignore
  158. block_offset = i % self.block_size # type: ignore
  159. slot = block_number * self.block_size + block_offset
  160. slot_mapping.append(slot)
  161. num_prompt_tokens = len(input_tokens)
  162. input_tokens = torch.tensor(input_tokens,
  163. dtype=torch.long,
  164. device=self.device) # type: ignore
  165. input_positions = torch.tensor(input_positions,
  166. dtype=torch.long,
  167. device=self.device) # type: ignore
  168. slot_mapping = torch.tensor(slot_mapping,
  169. dtype=torch.long,
  170. device=self.device) # type: ignore
  171. attn_metadata = self.attn_backend.make_metadata(
  172. is_prompt=True,
  173. seq_lens=seq_lens,
  174. seq_lens_tensor=torch.tensor([]),
  175. max_decode_seq_len=0,
  176. num_prefills=len(seq_lens),
  177. num_prefill_tokens=num_prompt_tokens,
  178. num_decode_tokens=0,
  179. block_tables=torch.tensor([]),
  180. slot_mapping=slot_mapping,
  181. )
  182. multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
  183. return (input_tokens, input_positions, attn_metadata, seq_lens,
  184. multi_modal_kwargs)
  185. def _prepare_decode(
  186. self,
  187. seq_group_metadata_list: List[SequenceGroupMetadata],
  188. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
  189. assert len(seq_group_metadata_list) > 0
  190. input_tokens: List[int] = []
  191. input_positions: List[int] = []
  192. slot_mapping: List[int] = []
  193. seq_lens: List[int] = []
  194. block_tables: List[List[int]] = []
  195. for seq_group_metadata in seq_group_metadata_list:
  196. assert not seq_group_metadata.is_prompt
  197. assert seq_group_metadata.token_chunk_size == 1
  198. seq_ids = list(seq_group_metadata.seq_data.keys())
  199. for seq_id in seq_ids:
  200. seq_data = seq_group_metadata.seq_data[seq_id]
  201. generation_token = seq_data.get_last_token_id()
  202. input_tokens.append(generation_token)
  203. seq_len = seq_data.get_len()
  204. position = seq_len - 1
  205. input_positions.append(position)
  206. seq_len = seq_len if self.sliding_window is None else min(
  207. seq_len, self.sliding_window)
  208. seq_lens.append(seq_len)
  209. block_table = seq_group_metadata.block_tables[seq_id]
  210. block_number = block_table[position // self.block_size]
  211. block_offset = position % self.block_size
  212. slot = block_number * self.block_size + block_offset
  213. slot_mapping.append(slot)
  214. if self.sliding_window is not None:
  215. sliding_window_blocks = (self.sliding_window //
  216. self.block_size)
  217. block_table = block_table[-sliding_window_blocks:]
  218. block_tables.append(block_table)
  219. max_decode_seq_len = max(seq_lens)
  220. input_tokens = torch.tensor(input_tokens,
  221. dtype=torch.long,
  222. device=self.device)
  223. input_positions = torch.tensor(input_positions,
  224. dtype=torch.long,
  225. device=self.device)
  226. slot_mapping = torch.tensor(slot_mapping,
  227. dtype=torch.long,
  228. device=self.device)
  229. seq_lens_tensor = torch.tensor(seq_lens,
  230. dtype=torch.int,
  231. device=self.device)
  232. block_tables = make_tensor_with_pad(
  233. block_tables,
  234. pad=0,
  235. dtype=torch.int,
  236. device=self.device,
  237. )
  238. attn_metadata = self.attn_backend.make_metadata(
  239. is_prompt=False,
  240. slot_mapping=slot_mapping,
  241. seq_lens=seq_lens,
  242. seq_lens_tensor=seq_lens_tensor,
  243. max_decode_seq_len=max_decode_seq_len,
  244. num_prefill_tokens=0,
  245. num_decode_tokens=len(input_tokens),
  246. num_prefills=0,
  247. block_tables=block_tables,
  248. )
  249. return (
  250. input_tokens,
  251. input_positions,
  252. attn_metadata,
  253. )
  254. def make_model_input_from_broadcasted_tensor_dict(
  255. self,
  256. tensor_dict: Dict[str, Any],
  257. ) -> CPUModelInput:
  258. return CPUModelInput.from_broadcasted_tensor_dict(
  259. tensor_dict,
  260. attn_backend=self.attn_backend,
  261. )
  262. def prepare_model_input(
  263. self,
  264. seq_group_metadata_list: List[SequenceGroupMetadata],
  265. virtual_engine: int = 0,
  266. finished_requests_ids: Optional[List[str]] = None
  267. ) -> CPUModelInput:
  268. multi_modal_kwargs = None
  269. # NOTE: We assume that all sequences in the group are all prompts or
  270. # all decodes.
  271. is_prompt = seq_group_metadata_list[0].is_prompt
  272. # Prepare input tensors.
  273. if is_prompt:
  274. (input_tokens, input_positions, attn_metadata, seq_lens,
  275. multi_modal_kwargs
  276. ) = self._prepare_prompt(seq_group_metadata_list)
  277. else:
  278. (input_tokens, input_positions,
  279. attn_metadata) = self._prepare_decode(seq_group_metadata_list)
  280. seq_lens = []
  281. sampling_metadata = SamplingMetadata.prepare(
  282. seq_group_metadata_list,
  283. seq_lens,
  284. # query_lens is not needed if chunked prefill is not
  285. # supported. Since CPU worker doesn't support chunked prefill
  286. # just use seq_lens instead.
  287. seq_lens,
  288. self.device,
  289. pin_memory=False,
  290. generators=self.get_generators(finished_requests_ids))
  291. return CPUModelInput(
  292. input_tokens=input_tokens,
  293. input_positions=input_positions,
  294. attn_metadata=attn_metadata,
  295. sampling_metadata=sampling_metadata,
  296. multi_modal_kwargs=multi_modal_kwargs,
  297. )
  298. @torch.no_grad()
  299. def execute_model(
  300. self,
  301. model_input: CPUModelInput,
  302. kv_caches: List[torch.Tensor],
  303. intermediate_tensors: Optional[IntermediateTensors] = None,
  304. num_steps: int = 1,
  305. ) -> Optional[List[SamplerOutput]]:
  306. if num_steps > 1:
  307. raise ValueError(
  308. "CPU worker does not support multi-step execution.")
  309. model_executable = self.model
  310. execute_model_kwargs = {
  311. "input_ids":
  312. model_input.input_tokens,
  313. "positions":
  314. model_input.input_positions,
  315. "kv_caches":
  316. kv_caches,
  317. "attn_metadata":
  318. model_input.attn_metadata,
  319. **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
  320. device=self.device),
  321. }
  322. hidden_states = model_executable(**execute_model_kwargs)
  323. # Compute the logits.
  324. logits = self.model.compute_logits(hidden_states,
  325. model_input.sampling_metadata)
  326. # Only perform sampling in the driver worker.
  327. if not self.is_driver_worker:
  328. return []
  329. # Sample the next token.
  330. output = self.model.sample(
  331. logits=logits,
  332. sampling_metadata=model_input.sampling_metadata,
  333. )
  334. return [output]