draft_model_runner.py

from typing import List, Optional

import torch
from loguru import logger

from aphrodite import _custom_ops as ops

try:
    from aphrodite.attention.backends.flash_attn import FlashAttentionMetadata
except ModuleNotFoundError:
    # aphrodite_flash_attn is not installed, use the identical ROCm FA metadata
    from aphrodite.attention.backends.rocm_flash_attn import (
        ROCmFlashAttentionMetadata as FlashAttentionMetadata)

try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
    BatchDecodeWithPagedKVCacheWrapper = None
    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
    BatchPrefillWithPagedKVCacheWrapper = None
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig, ParallelConfig,
                                     PromptAdapterConfig, SchedulerConfig)
from aphrodite.common.sequence import (ExecuteModelRequest,
                                       IntermediateTensors, SamplerOutput)
from aphrodite.multimodal import MultiModalInputs
from aphrodite.task_handler.model_runner import (
    ModelInputForGPUWithSamplingMetadata, ModelRunner)

# A flag to enable debug prints for the updated input tensors
# before each step.
debug_advance_input = False
# A flag to allow GPU advance step for draft model runner.
# Set to False for debugging.
allow_gpu_advance_step = True


class TP1DraftModelRunner(ModelRunner):
    """Specialized model runner for the speculative decoding draft model.

    Since the draft model always executes k forward passes consecutively to
    generate k speculative tokens in a single speculative decoding step,
    we can avoid most CPU-GPU synchronization and data transfer overhead
    by keeping the model input and output tensors on the GPU the whole time.

    TODOs:
    1. Currently supports only flash-attn; add support for other attn_backends.
    2. Support TP > 1 (this requires some design work, because we do not
       expect any broadcasting inside execute_model).
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        return_hidden_states: bool = False,
    ):
        if return_hidden_states:
            raise ValueError(
                "return_hidden_states is not supported for TP1DraftModelRunner."
            )

        super().__init__(
            model_config=model_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            cache_config=cache_config,
            load_config=load_config,
            lora_config=lora_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
            prompt_adapter_config=prompt_adapter_config,
            return_hidden_states=return_hidden_states,
        )

        self.flashinfer_decode_workspace_buffer = None
        self.flashinfer_decode_wrapper = None
        self.flashinfer_prefill_workspace_buffer = None
        self.flashinfer_prefill_wrapper = None

    def _update_sampling_metadata(self, sampling_metadata, num_seqs,
                                  num_queries):

        assert sampling_metadata.num_prompts == 0
        assert len(sampling_metadata.seq_groups) == num_queries
        assert sampling_metadata.selected_token_indices.shape == (
            num_queries, )
        # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501

        # Verify that all sequences are decodes
        for i in range(num_queries):
            seq_group = sampling_metadata.seq_groups[i]

            assert seq_group.is_prompt is False  # No prompt
            assert seq_group.prompt_logprob_indices == []  # No prompt
            assert seq_group.sample_indices == [i]  # Simple
            assert seq_group.seq_len is None  # Decode
            assert seq_group.query_len is None  # Decode

    def _gpu_advance_step(
            self, model_input: ModelInputForGPUWithSamplingMetadata,
            last_output: SamplerOutput
    ) -> ModelInputForGPUWithSamplingMetadata:
        # Currently, we expect "decode mode" only
        assert not model_input.is_prompt

        # Get num_seqs
        num_seqs = len(model_input.seq_lens)
        num_queries = len(model_input.query_lens)

        # Get output tokens GPU tensor
        sampled_token_ids = last_output.sampled_token_ids
        assert sampled_token_ids is not None

        # Update attn_metadata
        attn_metadata = model_input.attn_metadata
        assert isinstance(attn_metadata, FlashAttentionMetadata)
        attn_metadata.advance_step(num_seqs, num_queries)

        # Update GPU tensors
        ops.advance_step(num_seqs=num_seqs,
                         num_queries=num_queries,
                         block_size=self.block_size,
                         input_tokens=model_input.input_tokens,
                         sampled_token_ids=sampled_token_ids,
                         input_positions=model_input.input_positions,
                         seq_lens=attn_metadata.seq_lens_tensor,
                         slot_mapping=attn_metadata.slot_mapping,
                         block_tables=attn_metadata.block_tables)

        # Update sampling_metadata
        sampling_metadata = model_input.sampling_metadata
        self._update_sampling_metadata(sampling_metadata, num_seqs,
                                       num_queries)

        # Create new input
        new_model_input = self._model_input_cls(
            input_tokens=model_input.input_tokens,
            input_positions=model_input.input_positions,
            attn_metadata=attn_metadata,
            seq_lens=attn_metadata.seq_lens,
            query_lens=model_input.query_lens,
            lora_mapping=model_input.lora_mapping,
            lora_requests=model_input.lora_requests,
            multi_modal_kwargs=model_input.multi_modal_kwargs,
            sampling_metadata=model_input.sampling_metadata,
            is_prompt=False,
        )

        # Ensure we skip CPU samples
        assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True

        # We can reuse sampling tensors since every decode iteration is the same
        new_model_input.sampling_metadata.reuse_sampling_tensors = True

        if debug_advance_input:
            logger.debug("NEW INPUT: ")
            logger.debug(f"  input_tokens = {new_model_input.input_tokens}")
            logger.debug("  input_positions = "
                         f"{new_model_input.input_positions}")
            logger.debug(f"  seq_lens = {new_model_input.seq_lens}")
            logger.debug(f"  query_lens = {new_model_input.query_lens}")
            logger.debug("  attn_metadata:")
            logger.debug("    seq_lens_tensor: "
                         f"{attn_metadata.seq_lens_tensor}")
            logger.debug(f"    slot_mapping: {attn_metadata.slot_mapping}")
            logger.debug(f"    block_tables: {attn_metadata.block_tables}")

        return new_model_input
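
    # NOTE: A rough, purely illustrative sketch of what the fused
    # ops.advance_step kernel does for each decode sequence i (this is an
    # assumption about the kernel's effect, not its actual implementation;
    # the names mirror the keyword arguments passed above):
    #
    #   new_seq_len = seq_lens[i] + 1                # sequence grew by one
    #   input_tokens[i] = sampled_token_ids[i]       # feed last sampled token
    #   input_positions[i] = new_seq_len - 1         # position of that token
    #   block = block_tables[i][(new_seq_len - 1) // block_size]
    #   slot_mapping[i] = block * block_size + (new_seq_len - 1) % block_size
    #   seq_lens[i] = new_seq_len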

    def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
        """Determines if draft_model_runner GPU multi-step can be used.

        Currently required conditions are:
            1. Only decodes
            2. Only flash-attn
            3. No LORA
            4. No prompt_adapter_config
        """
        if not allow_gpu_advance_step:
            return False

        # We allow multi-step GPU only in decode mode
        for seq_group in execute_model_req.seq_group_metadata_list:
            if seq_group.is_prompt:
                return False

        # TODO: Add support for other attn backends
        if self.attn_backend.get_name() != "flash-attn":
            return False

        # TODO: Add support for LORA
        if self.lora_config:
            return False

        # TODO: Add soft-tuning prompt adapter support
        if self.prompt_adapter_config:
            return False

        return True
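
    # Illustrative caller-side usage of the check above (hypothetical sketch;
    # the real call sites live in the multi-step worker / spec-decode code):
    #
    #   if draft_runner.supports_gpu_multi_step(execute_model_req):
    #       # One call runs all k draft passes; inputs advance on the GPU.
    #       outputs = draft_runner.execute_model(model_input, kv_caches,
    #                                            num_steps=k)
    #   else:
    #       # Fallback: prepare inputs on the CPU and call once per step.
    #       outputs = [draft_runner.execute_model(step_input, kv_caches,
    #                                             num_steps=1)
    #                  for step_input in per_step_inputs]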

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Executes num_steps forward passes, advancing the input tensors
        on the GPU between steps. See supports_gpu_multi_step(..) for the
        pre-conditions.

        Optimizations used:
            1. Input tensors are updated directly on the GPU.
            2. Skips GPU=>CPU serialization of sampler outputs (we don't need
               them, since the later batch expansion uses the GPU outputs).
            3. Reuses sampling tensors (since we run only decodes and they
               follow a repeating sampling logic).
        """

        # When num_steps == 1, we execute the fallback here for the GPU
        # advance_step, which runs prepare_inputs on the CPU, and for each
        # spec iteration invokes this function only once
        # (see the multi-step worker code).
        is_fallback = num_steps == 1
        if not is_fallback:
            # Since we do not broadcast data inside execute_model anymore,
            # we need to figure out the best way to support TP > 1 in this
            # case, because we will at least need to broadcast the sampled
            # tokens to all workers.
            if not self.is_driver_worker:
                raise ValueError("TP1DraftModelRunner only supports TP=1.")

            # Sanity checks
            if self.lora_config is not None:
                raise ValueError("TP1DraftModelRunner has no support for LORA")
            if self.prompt_adapter_config is not None:
                raise ValueError("TP1DraftModelRunner has no support for "
                                 "prompt_adapter_config")
            if model_input.multi_modal_kwargs:
                raise ValueError(
                    "TP1DraftModelRunner has no support for multi_modal_kwargs"
                )
        else:
            if self.lora_config:
                assert model_input.lora_requests is not None
                assert model_input.lora_mapping is not None
                self.set_active_loras(model_input.lora_requests,
                                      model_input.lora_mapping)

            if self.prompt_adapter_config:
                assert model_input.prompt_adapter_requests is not None
                assert model_input.prompt_adapter_mapping is not None
                self.set_active_prompt_adapters(
                    model_input.prompt_adapter_requests,
                    model_input.prompt_adapter_mapping)

            if self.attn_backend.get_name() == "flashinfer":
                assert model_input.attn_metadata is not None
                assert model_input.input_tokens is not None
                if self.flashinfer_decode_workspace_buffer is None:
                    self.flashinfer_decode_workspace_buffer = torch.empty(
                        FLASHINFER_WORKSPACE_BUFFER_SIZE,
                        dtype=torch.uint8,
                        device=self.device)
                    self.flashinfer_decode_wrapper = \
                        BatchDecodeWithPagedKVCacheWrapper(
                            self.flashinfer_decode_workspace_buffer, "NHD")
                    self.flashinfer_prefill_workspace_buffer = torch.empty(
                        FLASHINFER_WORKSPACE_BUFFER_SIZE,
                        dtype=torch.uint8,
                        device=self.device)
                    self.flashinfer_prefill_wrapper = \
                        BatchPrefillWithPagedKVCacheWrapper(
                            self.flashinfer_prefill_workspace_buffer, "NHD")

                model_input.attn_metadata.prefill_wrapper = \
                    self.flashinfer_prefill_wrapper
                if model_input.attn_metadata.use_cuda_graph:
                    batch_size = model_input.input_tokens.shape[0]
                    model_input.attn_metadata.decode_wrapper = \
                        self.graph_runners[model_input.virtual_engine][
                            batch_size].flashinfer_decode_wrapper
                else:
                    model_input.attn_metadata.decode_wrapper = \
                        self.flashinfer_decode_wrapper
                model_input.attn_metadata.begin_forward()

        # Detect exec mode
        assert model_input.attn_metadata is not None
        use_cuda_graph = False
        if model_input.attn_metadata.num_prefills > 0:
            # In this case, execute_model(..) was called directly
            if num_steps > 1:
                raise ValueError(
                    "execute_model(..) of draft_model_runner can be called "
                    "directly only with a single-step prefill")
        else:
            # We can skip CPU samples for spec token generation.
            # (We do allow CPU samples for num_steps == 1 to support the
            # fallback case, where supports_gpu_multi_step(..) does not pass)
            model_input.sampling_metadata.skip_sampler_cpu_output = (
                not is_fallback)

            # This attn attr defines whether we use CUDA graphs
            use_cuda_graph = model_input.attn_metadata.use_cuda_graph

        # Get the model (or the CUDA graph runner for this batch size)
        if use_cuda_graph:
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = (self.graph_runners[model_input.virtual_engine]
                                [graph_batch_size])
        else:
            model_executable = self.model

        outputs: List[SamplerOutput] = []
        for step in range(num_steps):
            multi_modal_kwargs = model_input.multi_modal_kwargs or {}

            # Run the model
            hidden_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                kv_caches=kv_caches,
                attn_metadata=model_input.attn_metadata,
                intermediate_tensors=intermediate_tensors,
                **MultiModalInputs.as_kwargs(multi_modal_kwargs,
                                             device=self.device),
            )

            # Compute the logits.
            logits = self.model.compute_logits(hidden_states,
                                               model_input.sampling_metadata)

            # Sample the next token.
            outputs.append(
                self.model.sample(
                    logits=logits,
                    sampling_metadata=model_input.sampling_metadata,
                ))

            # Prepare inputs for the next step
            if step != num_steps - 1:
                model_input = self._gpu_advance_step(model_input, outputs[-1])

        return outputs
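
# Illustrative (hypothetical) shape of one speculative decoding step built on
# top of this runner; variable names are placeholders, and the verification
# stage is handled elsewhere (batch expansion against the target model):
#
#   draft_outputs = draft_runner.execute_model(model_input, kv_caches,
#                                              num_steps=k)
#   # With skip_sampler_cpu_output=True, draft_outputs[i].sampled_token_ids
#   # stay on the GPU; the batch-expansion scorer later uses them to build
#   # the target-model verification batch that accepts or rejects the k
#   # proposed tokens.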