draft_model_runner.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. from typing import List, Optional
  2. import torch
  3. from loguru import logger
  4. from aphrodite import _custom_ops as ops
  5. try:
  6. from aphrodite.attention.backends.flash_attn import FlashAttentionMetadata
  7. except ModuleNotFoundError:
  8. # vllm_flash_attn is not installed, use the identical ROCm FA metadata
  9. from aphrodite.attention.backends.rocm_flash_attn import (
  10. ROCmFlashAttentionMetadata as FlashAttentionMetadata)
  11. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  12. LoRAConfig, ModelConfig, MultiModalConfig,
  13. ParallelConfig, PromptAdapterConfig,
  14. SchedulerConfig)
  15. from aphrodite.common.sequence import (ExecuteModelRequest,
  16. IntermediateTensors, SamplerOutput)
  17. from aphrodite.task_handler.model_runner import (
  18. ModelInputForGPUWithSamplingMetadata, ModelRunner)
  19. # A flag to enable debug prints for the updated input tensors
  20. # before each step.
  21. debug_advance_input = False
  22. # A flag to allow GPU advance step for draft model runner.
  23. # Set to False for debugging.
  24. allow_gpu_advance_step = True
  25. class TP1DraftModelRunner(ModelRunner):
  26. """Specialized model runner for speculative decoding draft model.
  27. Since the draft model always execute k forward passes consecutively to
  28. generate k speculative tokens in a single speculative decoding step,
  29. we could get rid of most CPU-GPU synchronization and data transfer
  30. overheads by keeping model input and output tensors on GPU all the time.
  31. TODOs:
  32. 1. Currently supports only flash-attn, add support for other attn_backends.
  33. 2. Support TP > 1 (this requires some designs because we do not expect
  34. any broadcasting inside execute_model).
  35. """
  36. def __init__(
  37. self,
  38. model_config: ModelConfig,
  39. parallel_config: ParallelConfig,
  40. scheduler_config: SchedulerConfig,
  41. device_config: DeviceConfig,
  42. cache_config: CacheConfig,
  43. load_config: LoadConfig,
  44. lora_config: Optional[LoRAConfig],
  45. kv_cache_dtype: Optional[str] = "auto",
  46. is_driver_worker: bool = False,
  47. multimodal_config: Optional[MultiModalConfig] = None,
  48. prompt_adapter_config: Optional[PromptAdapterConfig] = None,
  49. return_hidden_states: bool = False,
  50. **kwargs,
  51. ):
  52. if return_hidden_states:
  53. raise ValueError(
  54. "return_hidden_states is not supported for TP1DraftModelRunner."
  55. )
  56. super().__init__(
  57. model_config=model_config,
  58. parallel_config=parallel_config,
  59. scheduler_config=scheduler_config,
  60. device_config=device_config,
  61. cache_config=cache_config,
  62. load_config=load_config,
  63. lora_config=lora_config,
  64. kv_cache_dtype=kv_cache_dtype,
  65. is_driver_worker=is_driver_worker,
  66. multimodal_config=multimodal_config,
  67. prompt_adapter_config=prompt_adapter_config,
  68. return_hidden_states=return_hidden_states,
  69. **kwargs,
  70. )
  71. def _update_flash_attn_metadata(self, attn_metadata, num_seqs,
  72. num_queries):
  73. assert isinstance(attn_metadata, FlashAttentionMetadata)
  74. if num_seqs != num_queries:
  75. assert num_seqs > num_queries
  76. assert attn_metadata.use_cuda_graph
  77. assert attn_metadata.num_prefills == 0
  78. assert attn_metadata.num_prefill_tokens == 0
  79. assert attn_metadata.num_decode_tokens == num_seqs
  80. assert attn_metadata.slot_mapping.shape == (num_seqs, )
  81. assert len(attn_metadata.seq_lens) == num_seqs
  82. assert attn_metadata.seq_lens_tensor.shape == (num_seqs, )
  83. assert attn_metadata.max_query_len == 1
  84. assert attn_metadata.max_prefill_seq_len == 0
  85. assert attn_metadata.max_decode_seq_len == max(attn_metadata.seq_lens)
  86. assert attn_metadata.query_start_loc.shape == (num_queries + 1, )
  87. assert attn_metadata.seq_start_loc.shape == (num_seqs + 1, )
  88. assert attn_metadata.context_lens_tensor.shape == (num_queries, )
  89. assert attn_metadata.block_tables.shape[0] == num_seqs
  90. # Update query lengths. Note that we update only queries and not seqs,
  91. # since tensors may be padded due to captured cuda graph batch size
  92. for i in range(num_queries):
  93. attn_metadata.seq_lens[i] += 1
  94. attn_metadata.max_decode_seq_len = max(attn_metadata.seq_lens)
  95. def _update_sampling_metadata(self, sampling_metadata, num_seqs,
  96. num_queries):
  97. assert sampling_metadata.num_prompts == 0
  98. assert len(sampling_metadata.seq_groups) == num_queries
  99. assert sampling_metadata.selected_token_indices.shape == (
  100. num_queries, )
  101. # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501
  102. # Verify that all sequences are decodes
  103. for i in range(num_queries):
  104. seq_group = sampling_metadata.seq_groups[i]
  105. assert seq_group.is_prompt is False # No prompt
  106. assert seq_group.prompt_logprob_indices == [] # No prompt
  107. assert seq_group.sample_indices == [i] # Simple
  108. assert seq_group.seq_len is None # Decode
  109. assert seq_group.query_len is None # Decode
  110. def _gpu_advance_step(
  111. self, model_input: ModelInputForGPUWithSamplingMetadata,
  112. last_output: SamplerOutput
  113. ) -> ModelInputForGPUWithSamplingMetadata:
  114. # Currently, we expect "decode mode" only
  115. assert not model_input.is_prompt
  116. # Get num_seqs
  117. num_seqs = len(model_input.seq_lens)
  118. num_queries = len(model_input.query_lens)
  119. # Get output tokens GPU tensor
  120. sampled_token_ids = last_output.sampled_token_ids
  121. assert sampled_token_ids is not None
  122. # Update attn_metadata
  123. attn_metadata = model_input.attn_metadata
  124. assert isinstance(attn_metadata, FlashAttentionMetadata)
  125. self._update_flash_attn_metadata(attn_metadata, num_seqs, num_queries)
  126. # Update GPU tensors
  127. ops.advance_step(num_seqs=num_seqs,
  128. num_queries=num_queries,
  129. block_size=self.block_size,
  130. input_tokens=model_input.input_tokens,
  131. sampled_token_ids=sampled_token_ids,
  132. input_positions=model_input.input_positions,
  133. seq_lens=attn_metadata.seq_lens_tensor,
  134. slot_mapping=attn_metadata.slot_mapping,
  135. block_tables=attn_metadata.block_tables)
  136. # Update sampling_metadata
  137. sampling_metadata = model_input.sampling_metadata
  138. self._update_sampling_metadata(sampling_metadata, num_seqs,
  139. num_queries)
  140. # Create new input
  141. new_model_input = self._model_input_cls(
  142. input_tokens=model_input.input_tokens,
  143. input_positions=model_input.input_positions,
  144. attn_metadata=attn_metadata,
  145. seq_lens=attn_metadata.seq_lens,
  146. query_lens=model_input.query_lens,
  147. lora_mapping=model_input.lora_mapping,
  148. lora_requests=model_input.lora_requests,
  149. multi_modal_kwargs=model_input.multi_modal_kwargs,
  150. sampling_metadata=model_input.sampling_metadata,
  151. is_prompt=False,
  152. )
  153. # Ensure we skip CPU samples
  154. assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True
  155. # We can reuse sampling tensors since every decode iteration is the same
  156. new_model_input.sampling_metadata.reuse_sampling_tensors = True
  157. if debug_advance_input:
  158. logger.debug("NEW INPUT: ")
  159. logger.debug(f" input_tokens = {new_model_input.input_tokens}")
  160. logger.debug(" input_positions = "
  161. f"{new_model_input.input_positions}")
  162. logger.debug(f" seq_lens = {new_model_input.seq_lens}")
  163. logger.debug(f" query_lens = {new_model_input.query_lens}")
  164. logger.debug(" attn_metadata:")
  165. logger.debug(" seq_lens_tensor: "
  166. f"{attn_metadata.seq_lens_tensor}")
  167. logger.debug(f" slot_mapping: {attn_metadata.slot_mapping}")
  168. logger.debug(f" block_tables: {attn_metadata.block_tables}")
  169. return new_model_input
  170. def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
  171. """Determines if draft_model_runner GPU multi-step can be used.
  172. Currently required conditions are:
  173. 1. Only decodes
  174. 2. Only flash-attn
  175. 3. No LORA
  176. 4. No prompt_adapter_config
  177. """
  178. if not allow_gpu_advance_step:
  179. return False
  180. # We allow multi-step GPU only in decode mode
  181. for seq_group in execute_model_req.seq_group_metadata_list:
  182. if seq_group.is_prompt:
  183. return False
  184. # TODO: Add support for other attn backends
  185. if self.attn_backend.get_name() != "flash-attn":
  186. return False
  187. # TODO: Add support for LORA
  188. if self.lora_config:
  189. return False
  190. # TODO: Add soft-tuning prompt adapter support
  191. if self.prompt_adapter_config:
  192. return False
  193. return True
  194. @torch.inference_mode()
  195. def execute_model(
  196. self,
  197. model_input: ModelInputForGPUWithSamplingMetadata,
  198. kv_caches: List[torch.Tensor],
  199. intermediate_tensors: Optional[IntermediateTensors] = None,
  200. num_steps: int = 1,
  201. ) -> Optional[List[SamplerOutput]]:
  202. """Executes num_steps forward passes with advacement of input tensors
  203. on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
  204. Optimizations used:
  205. 1. Input tensors are updated on the GPU directly
  206. 2. Skips GPU=>CPU serialization of sampler outputs (we don't need
  207. them since we do batch expansion later that uses GPU outputs)
  208. 3. Reuses sampling tensors (since we run only decodes and they have
  209. a repeating sampling logic)
  210. """
  211. # When num_steps == 1, we execute the fallback here for the GPU
  212. # advance_step, which runs prepare_inputs on CPU and for each spec
  213. # iteration invokes this function only once
  214. # (Look at multi-step-worker code)
  215. is_fallback = num_steps == 1
  216. if not is_fallback:
  217. # Since we do not broadcast data inside execute_model anymore,
  218. # we need to figure out the best way to support TP > 1 in this
  219. # case, because we will at least need to broadcast the sampled
  220. # tokens to all workers.
  221. if not self.is_driver_worker:
  222. raise ValueError("TP1DraftModelRunner only supports TP=1.")
  223. # Sanity
  224. if self.lora_config is not None:
  225. raise ValueError("TP1DraftModelRunner has no support for LORA")
  226. if self.prompt_adapter_config is not None:
  227. raise ValueError("TP1DraftModelRunner has no support for "
  228. "prompt_adapter_config")
  229. if model_input.multi_modal_kwargs:
  230. raise ValueError(
  231. "TP1DraftModelRunner has no support for multi_modal_kwargs"
  232. )
  233. else:
  234. if self.lora_config:
  235. assert model_input.lora_requests is not None
  236. assert model_input.lora_mapping is not None
  237. self.set_active_loras(model_input.lora_requests,
  238. model_input.lora_mapping)
  239. if self.prompt_adapter_config:
  240. assert model_input.prompt_adapter_requests is not None
  241. assert model_input.prompt_adapter_mapping is not None
  242. self.set_active_prompt_adapters(
  243. model_input.prompt_adapter_requests,
  244. model_input.prompt_adapter_mapping)
  245. # Detect exec mode
  246. assert model_input.attn_metadata is not None
  247. use_cuda_graph = False
  248. if model_input.attn_metadata.num_prefills > 0:
  249. # In this case, execute_model(..) was called directly
  250. if num_steps > 1:
  251. raise ValueError(
  252. "execute_model(..) of draft_model_runner can be called "
  253. "directly only with a single-step prefill")
  254. else:
  255. # We can skip CPU samples for spec token generation.
  256. # (We do allow CPU samples for num_steps == 1 to support the
  257. # fallback case, where supports_gpu_multi_step(..) does not pass)
  258. model_input.sampling_metadata.skip_sampler_cpu_output = (
  259. not is_fallback)
  260. # Attn attr defines if we use cuda graphs
  261. use_cuda_graph = model_input.attn_metadata.use_cuda_graph
  262. # Get model
  263. if use_cuda_graph:
  264. graph_batch_size = model_input.input_tokens.shape[0]
  265. model_executable = (self.graph_runners[model_input.virtual_engine]
  266. [graph_batch_size])
  267. else:
  268. model_executable = self.model
  269. outputs: List[SamplerOutput] = []
  270. for step in range(num_steps):
  271. multi_modal_kwargs = model_input.multi_modal_kwargs or {}
  272. # Run model
  273. hidden_states = model_executable(
  274. input_ids=model_input.input_tokens,
  275. positions=model_input.input_positions,
  276. kv_caches=kv_caches,
  277. attn_metadata=model_input.attn_metadata,
  278. intermediate_tensors=intermediate_tensors,
  279. **multi_modal_kwargs,
  280. )
  281. # Compute the logits.
  282. logits = self.model.compute_logits(hidden_states,
  283. model_input.sampling_metadata)
  284. # Sample the next token.
  285. outputs.append(
  286. self.model.sample(
  287. logits=logits,
  288. sampling_metadata=model_input.sampling_metadata,
  289. ))
  290. # Prepare inputs for the next step
  291. if step != num_steps - 1:
  292. model_input = self._gpu_advance_step(model_input, outputs[-1])
  293. return outputs