cpu_model_runner.py

from typing import List, Optional, Tuple

import torch
from torch import nn

from aphrodite.attention import AttentionMetadata, get_attn_backend
from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig, ParallelConfig,
                                     SchedulerConfig, VisionLanguageConfig)
from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
from aphrodite.common.utils import make_tensor_with_pad
from aphrodite.distributed import broadcast_tensor_dict
from aphrodite.modeling import SamplingMetadata
from aphrodite.modeling.model_loader import get_model
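
# Descriptive comment (added): sentinel slot id used in the slot mapping for
# tokens whose KV entries are not written to the cache (e.g. prompt tokens
# that fall outside the sliding window).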
_PAD_SLOT_ID = -1


class CPUModelRunner:
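    """Model runner for the CPU backend.

    Builds flat input tensors and attention metadata from the scheduler's
    SequenceGroupMetadata and runs the forward, logits, and sampling steps.
    """
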
    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        # Currently, the CPU worker doesn't support chunked prefill.
        assert self.scheduler_config.chunked_prefill_enabled is False
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.vision_language_config = vision_language_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker

        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.attn_backend = get_attn_backend(
            self.model_config.get_num_attention_heads(self.parallel_config),
            self.model_config.get_head_size(),
            self.model_config.get_num_kv_heads(self.parallel_config),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
        )

        # Lazy initialization.
        self.model: nn.Module  # Set after load_model.

    def load_model(self) -> None:
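        """Load the model weights onto the CPU device."""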
        self.model = get_model(
            model_config=self.model_config,
            load_config=self.load_config,
            device_config=self.device_config,
            vision_language_config=self.vision_language_config,
            lora_config=self.lora_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            cache_config=self.cache_config)

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               Optional[torch.Tensor]]:
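        """Build flat input tensors and attention metadata for prefill.

        Each sequence group must contain exactly one prompt sequence. Returns
        the token ids, positions, attention metadata, per-sequence prompt
        lengths, and the optional multi-modal input tensor.
        """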
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        multi_modal_input_list: List[torch.Tensor] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            computed_len = seq_data.get_num_computed_tokens()
            seq_len = len(prompt_tokens)

            seq_lens.append(seq_len)  # Prompt token num
            input_tokens.extend(prompt_tokens)  # Token ids

            # Token position ids
            # NOTE: Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(list(range(computed_len, seq_len)))

            if seq_group_metadata.multi_modal_data:
                multi_modal_input_list.append(
                    seq_group_metadata.multi_modal_data.data)

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, seq_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                start_idx = max(0, seq_len - self.sliding_window)

            for i in range(computed_len, seq_len):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i //
                                           self.block_size]  # type: ignore
                block_offset = i % self.block_size  # type: ignore
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        if multi_modal_input_list:
            assert self.vision_language_config, (
                "Multi-modal inputs are only supported by "
                "vision language models.")
            multi_modal_input = torch.cat(multi_modal_input_list,
                                          dim=0).to(self.device)
        else:
            multi_modal_input = None

        num_prompt_tokens = len(input_tokens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
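
        # Descriptive comment (added): prefill-only metadata. block_tables is
        # left empty here; the prompt's KV entries are written to the cache
        # through the slot mapping built above.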
        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            seq_lens=seq_lens,
            seq_lens_tensor=None,
            max_decode_seq_len=None,
            num_prefills=len(seq_lens),
            num_prefill_tokens=num_prompt_tokens,
            num_decode_tokens=0,
            block_tables=torch.tensor([]),
            slot_mapping=slot_mapping,
        )
        return (input_tokens, input_positions, attn_metadata, seq_lens,
                multi_modal_input)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
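        """Build flat input tensors and attention metadata for decode.

        Each running sequence contributes exactly one new token (its last
        generated token) plus the block table needed to read its KV cache.
        """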
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        block_tables: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            assert seq_group_metadata.token_chunk_size == 1

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append(position)

                seq_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                seq_lens.append(seq_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)
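
                # Descriptive comment (added): with sliding-window attention
                # only the last `sliding_window // block_size` blocks are
                # read, so the block table is truncated to that suffix.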
                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        max_decode_seq_len = max(seq_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)

        max_block_table_len = max(
            len(block_table) for block_table in block_tables)
        block_tables = make_tensor_with_pad(
            block_tables,
            max_len=max_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_seq_len=max_decode_seq_len,
            num_prefill_tokens=0,
            num_decode_tokens=len(input_tokens),
            num_prefills=0,
            block_tables=block_tables,
        )
        return (
            input_tokens,
            input_positions,
            attn_metadata,
        )

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
               Optional[torch.Tensor]]:
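        """Prepare model inputs on the driver worker and broadcast them.

        The driver builds the tensors from the sequence group metadata and
        broadcasts them; non-driver workers reconstruct the same inputs from
        the broadcast dictionary.
        """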
        multi_modal_input = None
        if self.is_driver_worker:
            # NOTE: We assume that the sequences in the group are either all
            # prompts or all decodes.
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, attn_metadata, seq_lens,
                 multi_modal_input
                 ) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions,
                 attn_metadata) = self._prepare_decode(seq_group_metadata_list)
                seq_lens = []
            sampling_metadata = SamplingMetadata.prepare(
                seq_group_metadata_list,
                seq_lens,
                # query_lens is not needed if chunked prefill is not
                # supported. Since the CPU worker doesn't support chunked
                # prefill, just use seq_lens instead.
                seq_lens,
                self.device,
                pin_memory=False)
            # Broadcast the metadata.
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
            }
            metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                seq_data=None,
                seq_lens=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                generators=None,
            )

        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, multi_modal_input)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
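        """Run one forward pass and sample the next tokens.

        Non-driver workers participate in the forward pass but return None;
        sampling is performed only on the driver worker.
        """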
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         multi_modal_input
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
        }
        if self.vision_language_config:
            execute_model_kwargs.update({"image_input": multi_modal_input})

        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return None

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output