cpu_model_runner.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. from dataclasses import dataclass
  2. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
  3. import torch
  4. from torch import nn
  5. from aphrodite.attention import AttentionMetadata, get_attn_backend
  6. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  7. LoRAConfig, ModelConfig, ParallelConfig,
  8. PromptAdapterConfig, SchedulerConfig)
  9. from aphrodite.common.sequence import (IntermediateTensors,
  10. SequenceGroupMetadata)
  11. from aphrodite.common.utils import (STR_NOT_IMPL_ENC_DEC_ERR_STRS,
  12. make_tensor_with_pad)
  13. from aphrodite.modeling.layers.sampler import SamplerOutput
  14. from aphrodite.modeling.model_loader import get_model
  15. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  16. from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
  17. MultiModalInputs)
  18. from aphrodite.worker.model_runner_base import (
  19. ModelRunnerBase, ModelRunnerInputBase,
  20. _add_attn_metadata_broadcastable_dict,
  21. _add_sampling_metadata_broadcastable_dict,
  22. _init_attn_metadata_from_tensor_dict,
  23. _init_sampling_metadata_from_tensor_dict)
  24. if TYPE_CHECKING:
  25. from aphrodite.attention.backends.abstract import AttentionBackend
  26. _PAD_SLOT_ID = -1
  27. @dataclass(frozen=True)
  28. class CPUModelInput(ModelRunnerInputBase):
  29. """
  30. Used by the CPUModelRunner.
  31. """
  32. input_tokens: Optional[torch.Tensor] = None
  33. input_positions: Optional[torch.Tensor] = None
  34. attn_metadata: Optional["AttentionMetadata"] = None
  35. sampling_metadata: Optional["SamplingMetadata"] = None
  36. multi_modal_kwargs: Optional[BatchedTensorInputs] = None
  37. virtual_engine: Optional[int] = None
  38. def as_broadcastable_tensor_dict(
  39. self) -> Dict[str, Union[int, torch.Tensor]]:
  40. tensor_dict = {
  41. "input_tokens": self.input_tokens,
  42. "input_positions": self.input_positions,
  43. "multi_modal_kwargs": self.multi_modal_kwargs,
  44. }
  45. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  46. _add_sampling_metadata_broadcastable_dict(tensor_dict,
  47. self.sampling_metadata)
  48. return tensor_dict
  49. @classmethod
  50. def from_broadcasted_tensor_dict(
  51. cls: Type["CPUModelInput"],
  52. tensor_dict: Dict[str, Any],
  53. attn_backend: Optional["AttentionBackend"] = None
  54. ) -> "CPUModelInput":
  55. tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
  56. if attn_backend is not None:
  57. tensor_dict = _init_attn_metadata_from_tensor_dict(
  58. attn_backend, tensor_dict)
  59. return cls(**tensor_dict)
  60. class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
  61. def __init__(
  62. self,
  63. model_config: ModelConfig,
  64. parallel_config: ParallelConfig,
  65. scheduler_config: SchedulerConfig,
  66. device_config: DeviceConfig,
  67. cache_config: CacheConfig,
  68. load_config: LoadConfig,
  69. lora_config: Optional[LoRAConfig],
  70. kv_cache_dtype: Optional[str] = "auto",
  71. prompt_adapter_config: Optional[PromptAdapterConfig] = None,
  72. is_driver_worker: bool = False,
  73. *args,
  74. **kwargs,
  75. ):
  76. self.model_config = model_config
  77. self.parallel_config = parallel_config
  78. self.scheduler_config = scheduler_config
  79. # Currently, CPU worker doesn't support chunked prefill.
  80. assert self.scheduler_config.chunked_prefill_enabled is False
  81. self.device_config = device_config
  82. self.cache_config = cache_config
  83. self.lora_config = lora_config
  84. self.prompt_adapter_config = prompt_adapter_config
  85. self.load_config = load_config
  86. self.is_driver_worker = is_driver_worker
  87. self.device = self.device_config.device
  88. self.kv_cache_dtype = kv_cache_dtype
  89. self.sliding_window = model_config.get_sliding_window()
  90. self.block_size = cache_config.block_size
  91. self.attn_backend = get_attn_backend(
  92. self.model_config.get_head_size(),
  93. self.model_config.get_sliding_window(),
  94. self.model_config.dtype,
  95. self.kv_cache_dtype,
  96. self.block_size,
  97. self.model_config.is_attention_free(),
  98. )
  99. # Multi-modal data support
  100. self.mm_registry = MULTIMODAL_REGISTRY
  101. self.multi_modal_input_mapper = self.mm_registry \
  102. .create_input_mapper(self.model_config)
  103. self.mm_registry.init_mm_limits_per_prompt(self.model_config)
  104. # Lazy initialization.
  105. self.model: nn.Module # Set after init_Model
  106. if self.model_config.is_encoder_decoder_model:
  107. raise NotImplementedError(
  108. STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])
  109. def load_model(self) -> None:
  110. self.model = get_model(model_config=self.model_config,
  111. load_config=self.load_config,
  112. device_config=self.device_config,
  113. lora_config=self.lora_config,
  114. parallel_config=self.parallel_config,
  115. scheduler_config=self.scheduler_config,
  116. cache_config=self.cache_config)
  117. def _prepare_prompt(
  118. self,
  119. seq_group_metadata_list: List[SequenceGroupMetadata],
  120. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
  121. BatchedTensorInputs]:
  122. assert len(seq_group_metadata_list) > 0
  123. input_tokens: List[int] = []
  124. input_positions: List[int] = []
  125. slot_mapping: List[int] = []
  126. seq_lens: List[int] = []
  127. multi_modal_inputs_list: List[MultiModalInputs] = []
  128. for seq_group_metadata in seq_group_metadata_list:
  129. assert seq_group_metadata.is_prompt
  130. seq_ids = list(seq_group_metadata.seq_data.keys())
  131. assert len(seq_ids) == 1
  132. seq_id = seq_ids[0]
  133. seq_data = seq_group_metadata.seq_data[seq_id]
  134. prompt_tokens = seq_data.get_token_ids()
  135. computed_len = seq_data.get_num_computed_tokens()
  136. seq_len = len(prompt_tokens)
  137. seq_lens.append(seq_len) # Prompt token num
  138. input_tokens.extend(prompt_tokens) # Token ids
  139. # Token position ids
  140. # NOTE: Here we assume that the first token in the prompt
  141. # is always the first token in the sequence.
  142. input_positions.extend(list(range(computed_len, seq_len)))
  143. mm_data = seq_group_metadata.multi_modal_data
  144. if mm_data:
  145. mm_kwargs = self.multi_modal_input_mapper(mm_data)
  146. multi_modal_inputs_list.append(mm_kwargs)
  147. # Compute the slot mapping.
  148. block_table = seq_group_metadata.block_tables[seq_id]
  149. # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
  150. # where start_idx is max(0, seq_len - sliding_window).
  151. # For example, if the prompt len is 10, sliding window is 8, and
  152. # block size is 4, the first two tokens are masked and the slot
  153. # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
  154. start_idx = 0
  155. if self.sliding_window is not None:
  156. start_idx = max(0, seq_len - self.sliding_window)
  157. for i in range(computed_len, seq_len):
  158. if i < start_idx:
  159. slot_mapping.append(_PAD_SLOT_ID)
  160. continue
  161. block_number = block_table[i //
  162. self.block_size] # type: ignore
  163. block_offset = i % self.block_size # type: ignore
  164. slot = block_number * self.block_size + block_offset
  165. slot_mapping.append(slot)
  166. num_prompt_tokens = len(input_tokens)
  167. input_tokens = torch.tensor(input_tokens,
  168. dtype=torch.long,
  169. device=self.device) # type: ignore
  170. input_positions = torch.tensor(input_positions,
  171. dtype=torch.long,
  172. device=self.device) # type: ignore
  173. slot_mapping = torch.tensor(slot_mapping,
  174. dtype=torch.long,
  175. device=self.device) # type: ignore
  176. attn_metadata = self.attn_backend.make_metadata(
  177. is_prompt=True,
  178. seq_lens=seq_lens,
  179. seq_lens_tensor=torch.tensor([]),
  180. max_decode_seq_len=0,
  181. num_prefills=len(seq_lens),
  182. num_prefill_tokens=num_prompt_tokens,
  183. num_decode_tokens=0,
  184. block_tables=torch.tensor([]),
  185. slot_mapping=slot_mapping,
  186. )
  187. multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
  188. return (input_tokens, input_positions, attn_metadata, seq_lens,
  189. multi_modal_kwargs)
  190. def _prepare_decode(
  191. self,
  192. seq_group_metadata_list: List[SequenceGroupMetadata],
  193. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
  194. assert len(seq_group_metadata_list) > 0
  195. input_tokens: List[int] = []
  196. input_positions: List[int] = []
  197. slot_mapping: List[int] = []
  198. seq_lens: List[int] = []
  199. block_tables: List[List[int]] = []
  200. for seq_group_metadata in seq_group_metadata_list:
  201. assert not seq_group_metadata.is_prompt
  202. assert seq_group_metadata.token_chunk_size == 1
  203. seq_ids = list(seq_group_metadata.seq_data.keys())
  204. for seq_id in seq_ids:
  205. seq_data = seq_group_metadata.seq_data[seq_id]
  206. generation_token = seq_data.get_last_token_id()
  207. input_tokens.append(generation_token)
  208. seq_len = seq_data.get_len()
  209. position = seq_len - 1
  210. input_positions.append(position)
  211. seq_len = seq_len if self.sliding_window is None else min(
  212. seq_len, self.sliding_window)
  213. seq_lens.append(seq_len)
  214. block_table = seq_group_metadata.block_tables[seq_id]
  215. block_number = block_table[position // self.block_size]
  216. block_offset = position % self.block_size
  217. slot = block_number * self.block_size + block_offset
  218. slot_mapping.append(slot)
  219. if self.sliding_window is not None:
  220. sliding_window_blocks = (self.sliding_window //
  221. self.block_size)
  222. block_table = block_table[-sliding_window_blocks:]
  223. block_tables.append(block_table)
  224. max_decode_seq_len = max(seq_lens)
  225. input_tokens = torch.tensor(input_tokens,
  226. dtype=torch.long,
  227. device=self.device)
  228. input_positions = torch.tensor(input_positions,
  229. dtype=torch.long,
  230. device=self.device)
  231. slot_mapping = torch.tensor(slot_mapping,
  232. dtype=torch.long,
  233. device=self.device)
  234. seq_lens_tensor = torch.tensor(seq_lens,
  235. dtype=torch.int,
  236. device=self.device)
  237. block_tables = make_tensor_with_pad(
  238. block_tables,
  239. pad=0,
  240. dtype=torch.int,
  241. device=self.device,
  242. )
  243. attn_metadata = self.attn_backend.make_metadata(
  244. is_prompt=False,
  245. slot_mapping=slot_mapping,
  246. seq_lens=seq_lens,
  247. seq_lens_tensor=seq_lens_tensor,
  248. max_decode_seq_len=max_decode_seq_len,
  249. num_prefill_tokens=0,
  250. num_decode_tokens=len(input_tokens),
  251. num_prefills=0,
  252. block_tables=block_tables,
  253. )
  254. return (
  255. input_tokens,
  256. input_positions,
  257. attn_metadata,
  258. )
  259. def make_model_input_from_broadcasted_tensor_dict(
  260. self,
  261. tensor_dict: Dict[str, Any],
  262. ) -> CPUModelInput:
  263. return CPUModelInput.from_broadcasted_tensor_dict(
  264. tensor_dict,
  265. attn_backend=self.attn_backend,
  266. )
  267. def prepare_model_input(
  268. self,
  269. seq_group_metadata_list: List[SequenceGroupMetadata],
  270. virtual_engine: int = 0,
  271. finished_requests_ids: Optional[List[str]] = None
  272. ) -> CPUModelInput:
  273. multi_modal_kwargs = None
  274. # NOTE: We assume that all sequences in the group are all prompts or
  275. # all decodes.
  276. is_prompt = seq_group_metadata_list[0].is_prompt
  277. # Prepare input tensors.
  278. if is_prompt:
  279. (input_tokens, input_positions, attn_metadata, seq_lens,
  280. multi_modal_kwargs
  281. ) = self._prepare_prompt(seq_group_metadata_list)
  282. else:
  283. (input_tokens, input_positions,
  284. attn_metadata) = self._prepare_decode(seq_group_metadata_list)
  285. seq_lens = []
  286. sampling_metadata = SamplingMetadata.prepare(
  287. seq_group_metadata_list,
  288. seq_lens,
  289. # query_lens is not needed if chunked prefill is not
  290. # supported. Since CPU worker doesn't support chunked prefill
  291. # just use seq_lens instead.
  292. seq_lens,
  293. self.device,
  294. pin_memory=False,
  295. generators=self.get_generators(finished_requests_ids))
  296. return CPUModelInput(
  297. input_tokens=input_tokens,
  298. input_positions=input_positions,
  299. attn_metadata=attn_metadata,
  300. sampling_metadata=sampling_metadata,
  301. multi_modal_kwargs=multi_modal_kwargs,
  302. )
  303. @torch.no_grad()
  304. def execute_model(
  305. self,
  306. model_input: CPUModelInput,
  307. kv_caches: List[torch.Tensor],
  308. intermediate_tensors: Optional[IntermediateTensors] = None,
  309. num_steps: int = 1,
  310. ) -> Optional[List[SamplerOutput]]:
  311. if num_steps > 1:
  312. raise ValueError(
  313. "CPU worker does not support multi-step execution.")
  314. model_executable = self.model
  315. execute_model_kwargs = {
  316. "input_ids":
  317. model_input.input_tokens,
  318. "positions":
  319. model_input.input_positions,
  320. "kv_caches":
  321. kv_caches,
  322. "attn_metadata":
  323. model_input.attn_metadata,
  324. **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
  325. device=self.device),
  326. }
  327. hidden_states = model_executable(**execute_model_kwargs)
  328. # Compute the logits.
  329. logits = self.model.compute_logits(hidden_states,
  330. model_input.sampling_metadata)
  331. # Only perform sampling in the driver worker.
  332. if not self.is_driver_worker:
  333. return []
  334. # Sample the next token.
  335. output = self.model.sample(
  336. logits=logits,
  337. sampling_metadata=model_input.sampling_metadata,
  338. )
  339. return [output]