# cpu_model_runner.py

from typing import List, Optional, Tuple

import torch
from torch import nn

from aphrodite.attention import AttentionMetadata, get_attn_backend
from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig, ParallelConfig,
                                     SchedulerConfig, VisionLanguageConfig)
from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
from aphrodite.common.utils import make_tensor_with_pad
from aphrodite.distributed import broadcast_tensor_dict
from aphrodite.modeling import SamplingMetadata
from aphrodite.modeling.model_loader import get_model

_PAD_SLOT_ID = -1


class CPUModelRunner:
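    """Model runner for CPU-only execution.

    Builds flat input tensors and attention metadata from the scheduler's
    ``SequenceGroupMetadata`` (separately for prefill and decode), broadcasts
    them from the driver worker to the other workers, runs the model forward
    pass, and samples the next tokens on the driver worker.
    """
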
    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        # Currently, the CPU worker doesn't support chunked prefill.
        assert self.scheduler_config.chunked_prefill_enabled is False
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.vision_language_config = vision_language_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker

        self.device = self.device_config.device
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.attn_backend = get_attn_backend(self.model_config.dtype)

        # Lazy initialization.
        self.model: nn.Module  # Set after load_model.

    def load_model(self) -> None:
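        """Load the model weights onto the CPU via the configured loader."""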
        self.model = get_model(
            model_config=self.model_config,
            load_config=self.load_config,
            device_config=self.device_config,
            vision_language_config=self.vision_language_config,
            lora_config=self.lora_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config)

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               Optional[torch.Tensor]]:
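        """Build flattened input tensors for a batch of prompt (prefill)
        sequences.

        Returns the token ids, positions, attention metadata, per-sequence
        prompt lengths, and the optional concatenated multi-modal input.
        """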
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        multi_modal_input_list: List[torch.Tensor] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            computed_len = seq_data.get_num_computed_tokens()
            seq_len = len(prompt_tokens)

            seq_lens.append(seq_len)  # Prompt token num
            input_tokens.extend(prompt_tokens)  # Token ids

            # Token position ids
            # NOTE: Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(list(range(computed_len, seq_len)))

            if seq_group_metadata.multi_modal_data:
                multi_modal_input_list.append(
                    seq_group_metadata.multi_modal_data.data)

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, seq_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                start_idx = max(0, seq_len - self.sliding_window)

            for i in range(computed_len, seq_len):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i //
                                           self.block_size]  # type: ignore
                block_offset = i % self.block_size  # type: ignore
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        if multi_modal_input_list:
            assert self.vision_language_config, (
                "Multi-modal inputs are only supported by "
                "vision language models.")
            multi_modal_input = torch.cat(multi_modal_input_list,
                                          dim=0).to(self.device)
        else:
            multi_modal_input = None

        num_prompt_tokens = len(input_tokens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            seq_lens=seq_lens,
            seq_lens_tensor=None,
            max_seq_len=None,
            num_prefills=len(seq_lens),
            num_prefill_tokens=num_prompt_tokens,
            num_decode_tokens=0,
            prefill_metadata=None,
            decode_metadata=None,
            block_tables=torch.tensor([]),
            slot_mapping=slot_mapping,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (input_tokens, input_positions, attn_metadata, seq_lens,
                multi_modal_input)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
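        """Build input tensors for a batch of decode sequences.

        Each sequence contributes exactly one new token; positions, the slot
        mapping, and (sliding-window-clipped) block tables are gathered for
        the attention backend.
        """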
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        block_tables: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            assert seq_group_metadata.token_chunk_size == 1

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append(position)

                seq_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                seq_lens.append(seq_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        max_seq_len = max(seq_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)

        max_block_table_len = max(
            len(block_table) for block_table in block_tables)
        block_tables = make_tensor_with_pad(
            block_tables,
            max_len=max_block_table_len,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_seq_len=max_seq_len,
            num_prefill_tokens=0,
            num_decode_tokens=len(input_tokens),
            num_prefills=0,
            prefill_metadata=None,
            decode_metadata=None,
            block_tables=block_tables,
            kv_cache_dtype=self.kv_cache_dtype,
        )
        return (
            input_tokens,
            input_positions,
            attn_metadata,
        )

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
               SamplingMetadata, Optional[torch.Tensor]]:
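        """Prepare model inputs on the driver worker and broadcast them.

        The driver worker builds the tensors (prefill or decode path) and the
        sampling metadata, then broadcasts them; non-driver workers
        reconstruct the same inputs from the broadcast.
        """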
        multi_modal_input = None
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are either
            # all prompts or all decodes.
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, attn_metadata, seq_lens,
                 multi_modal_input
                 ) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions,
                 attn_metadata) = self._prepare_decode(seq_group_metadata_list)
                seq_lens = []
            sampling_metadata = SamplingMetadata.prepare(
                seq_group_metadata_list,
                seq_lens,
                # query_lens is not needed if chunked prefill is not
                # supported. Since the CPU worker doesn't support chunked
                # prefill, just use seq_lens instead.
                seq_lens,
                self.device,
                pin_memory=False)
            # Broadcast the metadata.
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
            }
            metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                seq_data=None,
                seq_lens=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                generators=None,
            )

        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, multi_modal_input)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
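        """Run one forward pass and sample the next tokens.

        Returns the sampler output on the driver worker; non-driver workers
        return None, since sampling happens only on the driver.
        """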
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         multi_modal_input
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
        }
        if self.vision_language_config:
            execute_model_kwargs.update({"image_input": multi_modal_input})

        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return None

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output
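

# Illustrative call sequence (a sketch only; the actual wiring lives in the
# CPU worker, which is not shown in this file):
#
#   runner = CPUModelRunner(model_config, parallel_config, scheduler_config,
#                           device_config, cache_config, load_config,
#                           lora_config=None, vision_language_config=None,
#                           is_driver_worker=True)
#   runner.load_model()
#   # Then, once per scheduler step:
#   output = runner.execute_model(seq_group_metadata_list, kv_caches)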