# cpu_model_runner.py

from typing import List, Optional, Tuple

import torch
from torch import nn

from aphrodite.attention import AttentionMetadata, get_attn_backend
from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig, ParallelConfig,
                                     SchedulerConfig, VisionLanguageConfig)
from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
from aphrodite.common.utils import make_tensor_with_pad
from aphrodite.distributed import broadcast_tensor_dict
from aphrodite.modeling import SamplingMetadata
from aphrodite.modeling.model_loader import get_model

# Padding value used in the slot mapping for prompt tokens that are masked
# out by the sliding window (see _prepare_prompt).
_PAD_SLOT_ID = -1


class CPUModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        # Currently, CPU worker doesn't support chunked prefill.
        assert self.scheduler_config.chunked_prefill_enabled is False
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.vision_language_config = vision_language_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker

        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.attn_backend = get_attn_backend(
            self.model_config.get_num_attention_heads(self.parallel_config),
            self.model_config.get_head_size(),
            self.model_config.get_num_kv_heads(self.parallel_config),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
        )

        # Lazy initialization.
        self.model: nn.Module  # Set after load_model.

    def load_model(self) -> None:
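        """Load the model and set ``self.model`` (deferred from __init__)."""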
        self.model = get_model(
            model_config=self.model_config,
            load_config=self.load_config,
            device_config=self.device_config,
            vision_language_config=self.vision_language_config,
            lora_config=self.lora_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            cache_config=self.cache_config)

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               Optional[torch.Tensor]]:
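        """Flatten a batch of prompt (prefill) sequences into input token,
        position, and slot-mapping tensors, and build the prefill attention
        metadata.

        Returns (input_tokens, input_positions, attn_metadata, seq_lens,
        multi_modal_input).
        """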
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        multi_modal_input_list: List[torch.Tensor] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            computed_len = seq_data.get_num_computed_tokens()
            seq_len = len(prompt_tokens)

            seq_lens.append(seq_len)  # Prompt token num
            input_tokens.extend(prompt_tokens)  # Token ids

            # Token position ids
            # NOTE: Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(list(range(computed_len, seq_len)))

            if seq_group_metadata.multi_modal_data:
                multi_modal_input_list.append(
                    seq_group_metadata.multi_modal_data.data)

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, seq_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                start_idx = max(0, seq_len - self.sliding_window)

            for i in range(computed_len, seq_len):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i //
                                           self.block_size]  # type: ignore
                block_offset = i % self.block_size  # type: ignore
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        if multi_modal_input_list:
            assert self.vision_language_config, (
                "Multi-modal inputs are only supported by "
                "vision language models.")
            multi_modal_input = torch.cat(multi_modal_input_list,
                                          dim=0).to(self.device)
        else:
            multi_modal_input = None

        num_prompt_tokens = len(input_tokens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            seq_lens=seq_lens,
            seq_lens_tensor=None,
            max_decode_seq_len=None,
            num_prefills=len(seq_lens),
            num_prefill_tokens=num_prompt_tokens,
            num_decode_tokens=0,
            prefill_metadata=None,
            decode_metadata=None,
            block_tables=torch.tensor([]),
            slot_mapping=slot_mapping,
        )
        return (input_tokens, input_positions, attn_metadata, seq_lens,
                multi_modal_input)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
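        """Build input tensors and attention metadata for a batch of decode
        sequences, each of which contributes exactly one new token.

        Returns (input_tokens, input_positions, attn_metadata).
        """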
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        block_tables: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            assert seq_group_metadata.token_chunk_size == 1

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append(position)

                seq_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                seq_lens.append(seq_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        max_decode_seq_len = max(seq_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)

        max_block_table_len = max(
            len(block_table) for block_table in block_tables)
        block_tables = make_tensor_with_pad(
            block_tables,
            max_len=max_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_seq_len=max_decode_seq_len,
            num_prefill_tokens=0,
            num_decode_tokens=len(input_tokens),
            num_prefills=0,
            block_tables=block_tables,
        )
        return (
            input_tokens,
            input_positions,
            attn_metadata,
        )

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
               SamplingMetadata, Optional[torch.Tensor]]:
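        """On the driver worker, build the input tensors and sampling
        metadata from the scheduler output and broadcast them to the other
        workers; on non-driver workers, receive the broadcast and rebuild
        the metadata locally.
        """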
        multi_modal_input = None
        if self.is_driver_worker:
            # NOTE: We assume the sequences in the group are either all
            # prompts or all decodes.
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, attn_metadata, seq_lens,
                 multi_modal_input
                 ) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions,
                 attn_metadata) = self._prepare_decode(seq_group_metadata_list)
                seq_lens = []
            sampling_metadata = SamplingMetadata.prepare(
                seq_group_metadata_list,
                seq_lens,
                # query_lens is not needed if chunked prefill is not
                # supported. Since the CPU worker doesn't support chunked
                # prefill, just use seq_lens instead.
                seq_lens,
                self.device,
                pin_memory=False)
            # Broadcast the metadata.
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
            }
            metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                seq_data=None,
                seq_lens=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                generators=None,
            )

        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, multi_modal_input)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
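        """Run one forward pass over the prepared batch, compute the logits,
        and (on the driver worker only) sample the next tokens.
        """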
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         multi_modal_input
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
        }
        if self.vision_language_config:
            execute_model_kwargs.update({"image_input": multi_modal_input})

        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return None

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output