draft_model_runner.py

from typing import List, Optional

import torch
from loguru import logger

from aphrodite import _custom_ops as ops

try:
    from aphrodite.attention.backends.flash_attn import FlashAttentionMetadata
except ModuleNotFoundError:
    # aphrodite_flash_attn is not installed, use the identical ROCm FA metadata
    from aphrodite.attention.backends.rocm_flash_attn import (
        ROCmFlashAttentionMetadata as FlashAttentionMetadata)

try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper

    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
    BatchDecodeWithPagedKVCacheWrapper = None
    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
    BatchPrefillWithPagedKVCacheWrapper = None
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig, MultiModalConfig,
                                     ParallelConfig, PromptAdapterConfig,
                                     SchedulerConfig)
from aphrodite.common.sequence import (ExecuteModelRequest,
                                       IntermediateTensors, SamplerOutput)
from aphrodite.multimodal import MultiModalInputs
from aphrodite.task_handler.model_runner import (
    ModelInputForGPUWithSamplingMetadata, ModelRunner)

# A flag to enable debug prints for the updated input tensors
# before each step.
debug_advance_input = False
# A flag to allow GPU advance step for draft model runner.
# Set to False for debugging.
allow_gpu_advance_step = True


class TP1DraftModelRunner(ModelRunner):
    """Specialized model runner for the speculative decoding draft model.

    Since the draft model always executes k forward passes consecutively to
    generate k speculative tokens in a single speculative decoding step,
    we can get rid of most CPU-GPU synchronization and data transfer
    overheads by keeping the model input and output tensors on the GPU all
    the time.

    TODOs:
    1. Currently supports only flash-attn; add support for other
       attn_backends.
    2. Support TP > 1 (this requires some design work because we do not
       expect any broadcasting inside execute_model).
    """
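
    # Sketch of one speculative decoding step (illustrative only, assuming the
    # multi-step worker drives this runner): execute_model(..., num_steps=k)
    # runs the k draft forward passes back to back, and between passes
    # _gpu_advance_step feeds the sampled token back in as the next input
    # without copying anything to the CPU.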

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        multimodal_config: Optional[MultiModalConfig] = None,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        return_hidden_states: bool = False,
        **kwargs,
    ):
        if return_hidden_states:
            raise ValueError(
                "return_hidden_states is not supported for TP1DraftModelRunner."
            )

        super().__init__(
            model_config=model_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            cache_config=cache_config,
            load_config=load_config,
            lora_config=lora_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
            multimodal_config=multimodal_config,
            prompt_adapter_config=prompt_adapter_config,
            return_hidden_states=return_hidden_states,
            **kwargs,  # needed for uneven TP
        )

        self.flashinfer_decode_workspace_buffer = None
        self.flashinfer_decode_wrapper = None
        self.flashinfer_prefill_workspace_buffer = None
        self.flashinfer_prefill_wrapper = None

    def _update_flash_attn_metadata(self, attn_metadata, num_seqs,
                                    num_queries):
        assert isinstance(attn_metadata, FlashAttentionMetadata)

        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert attn_metadata.use_cuda_graph

        assert attn_metadata.num_prefills == 0
        assert attn_metadata.num_prefill_tokens == 0
        assert attn_metadata.num_decode_tokens == num_seqs
        assert attn_metadata.slot_mapping.shape == (num_seqs, )

        assert len(attn_metadata.seq_lens) == num_seqs
        assert attn_metadata.seq_lens_tensor.shape == (num_seqs, )
        assert attn_metadata.max_query_len == 1
        assert attn_metadata.max_prefill_seq_len == 0
        assert attn_metadata.max_decode_seq_len == max(attn_metadata.seq_lens)

        assert attn_metadata.query_start_loc.shape == (num_queries + 1, )
        assert attn_metadata.seq_start_loc.shape == (num_seqs + 1, )
        assert attn_metadata.context_lens_tensor.shape == (num_queries, )
        assert attn_metadata.block_tables.shape[0] == num_seqs

        # Advance the sequence lengths for the real queries only (the first
        # num_queries entries); the remaining entries may be padding added to
        # match a captured CUDA graph batch size, so we leave them untouched.
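        # For example (hypothetical numbers): with num_queries == 3 padded to
        # a captured graph batch of num_seqs == 4, only seq_lens[0:3] advance
        # by one; the padded fourth entry keeps its placeholder value.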
        for i in range(num_queries):
            attn_metadata.seq_lens[i] += 1
        attn_metadata.max_decode_seq_len = max(attn_metadata.seq_lens)

    def _update_sampling_metadata(self, sampling_metadata, num_seqs,
                                  num_queries):
        assert sampling_metadata.num_prompts == 0
        assert len(sampling_metadata.seq_groups) == num_queries
        assert sampling_metadata.selected_token_indices.shape == (
            num_queries, )
        # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501

        # Verify that all sequences are decodes
        for i in range(num_queries):
            seq_group = sampling_metadata.seq_groups[i]

            assert seq_group.is_prompt is False  # No prompt
            assert seq_group.prompt_logprob_indices == []  # No prompt
            assert seq_group.sample_indices == [i]  # Simple
            assert seq_group.seq_len is None  # Decode
            assert seq_group.query_len is None  # Decode

    def _gpu_advance_step(
            self, model_input: ModelInputForGPUWithSamplingMetadata,
            last_output: SamplerOutput
    ) -> ModelInputForGPUWithSamplingMetadata:
        # Currently, we expect "decode mode" only
        assert not model_input.is_prompt

        # Get num_seqs
        num_seqs = len(model_input.seq_lens)
        num_queries = len(model_input.query_lens)

        # Get output tokens GPU tensor
        sampled_token_ids = last_output.sampled_token_ids
        assert sampled_token_ids is not None

        # Update attn_metadata
        attn_metadata = model_input.attn_metadata
        assert isinstance(attn_metadata, FlashAttentionMetadata)
        self._update_flash_attn_metadata(attn_metadata, num_seqs, num_queries)

        # Update GPU tensors
        ops.advance_step(num_seqs=num_seqs,
                         num_queries=num_queries,
                         block_size=self.block_size,
                         input_tokens=model_input.input_tokens,
                         sampled_token_ids=sampled_token_ids,
                         input_positions=model_input.input_positions,
                         seq_lens=attn_metadata.seq_lens_tensor,
                         slot_mapping=attn_metadata.slot_mapping,
                         block_tables=attn_metadata.block_tables)
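        # Conceptually (a rough sketch, not the kernel's actual code), for
        # each real query i the fused kernel does something like:
        #     input_tokens[i] = sampled_token_ids[i]
        #     seq_lens[i] += 1
        #     input_positions[i] = seq_lens[i] - 1
        #     slot = block_tables[i][input_positions[i] // block_size]
        #     slot_mapping[i] = (slot * block_size
        #                        + input_positions[i] % block_size)
        # entirely on the GPU, so no host synchronization is required.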

        # Update sampling_metadata
        sampling_metadata = model_input.sampling_metadata
        self._update_sampling_metadata(sampling_metadata, num_seqs,
                                       num_queries)

        # Create new input
        new_model_input = self._model_input_cls(
            input_tokens=model_input.input_tokens,
            input_positions=model_input.input_positions,
            attn_metadata=attn_metadata,
            seq_lens=attn_metadata.seq_lens,
            query_lens=model_input.query_lens,
            lora_mapping=model_input.lora_mapping,
            lora_requests=model_input.lora_requests,
            multi_modal_kwargs=model_input.multi_modal_kwargs,
            sampling_metadata=model_input.sampling_metadata,
            is_prompt=False,
        )

        # Ensure we skip CPU samples
        assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True

        # We can reuse sampling tensors since every decode iteration is the same
        new_model_input.sampling_metadata.reuse_sampling_tensors = True

        if debug_advance_input:
            logger.debug("NEW INPUT: ")
            logger.debug(f"  input_tokens = {new_model_input.input_tokens}")
            logger.debug("  input_positions = "
                         f"{new_model_input.input_positions}")
            logger.debug(f"  seq_lens = {new_model_input.seq_lens}")
            logger.debug(f"  query_lens = {new_model_input.query_lens}")
            logger.debug("  attn_metadata:")
            logger.debug("    seq_lens_tensor: "
                         f"{attn_metadata.seq_lens_tensor}")
            logger.debug(f"    slot_mapping: {attn_metadata.slot_mapping}")
            logger.debug(f"    block_tables: {attn_metadata.block_tables}")

        return new_model_input

    def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
        """Determines if draft_model_runner GPU multi-step can be used.

        Currently required conditions are:
            1. Only decodes
            2. Only flash-attn
            3. No LoRA
            4. No prompt_adapter_config
        """
        if not allow_gpu_advance_step:
            return False

        # We allow multi-step GPU only in decode mode
        for seq_group in execute_model_req.seq_group_metadata_list:
            if seq_group.is_prompt:
                return False

        # TODO: Add support for other attn backends
        if self.attn_backend.get_name() != "flash-attn":
            return False

        # TODO: Add support for LoRA
        if self.lora_config:
            return False

        # TODO: Add soft-tuning prompt adapter support
        if self.prompt_adapter_config:
            return False

        return True
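
    # Caller-side gating, as an illustrative sketch (the real decision lives
    # in the multi-step worker, not in this file):
    #     if runner.supports_gpu_multi_step(execute_model_req):
    #         # one call produces all k draft tokens with GPU-only advances
    #         outputs = runner.execute_model(model_input, kv_caches,
    #                                        num_steps=k)
    #     else:
    #         # fallback: prepare inputs on the CPU and call once per token
    #         outputs = runner.execute_model(model_input, kv_caches,
    #                                        num_steps=1)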

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Executes num_steps forward passes, advancing the input tensors
        on the GPU. See supports_gpu_multi_step(..) for the pre-conditions.

        Optimizations used:
            1. Input tensors are updated on the GPU directly
            2. Skips GPU=>CPU serialization of sampler outputs (we don't need
               them since we do batch expansion later that uses GPU outputs)
            3. Reuses sampling tensors (since we run only decodes and they
               have a repeating sampling logic)
        """
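
        # Overview of the two paths below (explanatory note): num_steps > 1 is
        # the GPU multi-step path, which only sanity-checks unsupported
        # features; num_steps == 1 is the CPU-prepared fallback path, which
        # also activates LoRA / prompt adapters and sets up FlashInfer state
        # when those are in use.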

        # When num_steps == 1, we execute the fallback for the GPU
        # advance_step: the caller runs prepare_inputs on the CPU and invokes
        # this function only once per spec iteration.
        # (See the multi-step worker code.)
        is_fallback = num_steps == 1
        if not is_fallback:
            # Since we do not broadcast data inside execute_model anymore,
            # we need to figure out the best way to support TP > 1 in this
            # case, because we will at least need to broadcast the sampled
            # tokens to all workers.
            if not self.is_driver_worker:
                raise ValueError("TP1DraftModelRunner only supports TP=1.")

            # Sanity checks
            if self.lora_config is not None:
                raise ValueError("TP1DraftModelRunner has no support for LORA")
            if self.prompt_adapter_config is not None:
                raise ValueError("TP1DraftModelRunner has no support for "
                                 "prompt_adapter_config")
            if model_input.multi_modal_kwargs:
                raise ValueError(
                    "TP1DraftModelRunner has no support for multi_modal_kwargs"
                )
        else:
            if self.lora_config:
                assert model_input.lora_requests is not None
                assert model_input.lora_mapping is not None
                self.set_active_loras(model_input.lora_requests,
                                      model_input.lora_mapping)

            if self.prompt_adapter_config:
                assert model_input.prompt_adapter_requests is not None
                assert model_input.prompt_adapter_mapping is not None
                self.set_active_prompt_adapters(
                    model_input.prompt_adapter_requests,
                    model_input.prompt_adapter_mapping)

            if self.attn_backend.get_name() == "flashinfer":
                assert model_input.attn_metadata is not None
                assert model_input.input_tokens is not None
                if self.flashinfer_decode_workspace_buffer is None:
                    self.flashinfer_decode_workspace_buffer = torch.empty(
                        FLASHINFER_WORKSPACE_BUFFER_SIZE,
                        dtype=torch.uint8,
                        device=self.device)
                    self.flashinfer_decode_wrapper = \
                        BatchDecodeWithPagedKVCacheWrapper(
                            self.flashinfer_decode_workspace_buffer, "NHD")
                    self.flashinfer_prefill_workspace_buffer = torch.empty(
                        FLASHINFER_WORKSPACE_BUFFER_SIZE,
                        dtype=torch.uint8,
                        device=self.device)
                    self.flashinfer_prefill_wrapper = \
                        BatchPrefillWithPagedKVCacheWrapper(
                            self.flashinfer_prefill_workspace_buffer, "NHD")

                model_input.attn_metadata.prefill_wrapper = \
                    self.flashinfer_prefill_wrapper
                if model_input.attn_metadata.use_cuda_graph:
                    batch_size = model_input.input_tokens.shape[0]
                    model_input.attn_metadata.decode_wrapper = \
                        self.graph_runners[model_input.virtual_engine][
                            batch_size].flashinfer_decode_wrapper
                else:
                    model_input.attn_metadata.decode_wrapper = \
                        self.flashinfer_decode_wrapper
                model_input.attn_metadata.begin_forward()

        # Detect exec mode
        assert model_input.attn_metadata is not None
        use_cuda_graph = False
        if model_input.attn_metadata.num_prefills > 0:
            # In this case, execute_model(..) was called directly
            if num_steps > 1:
                raise ValueError(
                    "execute_model(..) of draft_model_runner can be called "
                    "directly only with a single-step prefill")
        else:
            # We can skip CPU samples for spec token generation.
            # (We do allow CPU samples for num_steps == 1 to support the
            # fallback case, where supports_gpu_multi_step(..) does not pass)
            model_input.sampling_metadata.skip_sampler_cpu_output = (
                not is_fallback)

            # Attn attr defines if we use cuda graphs
            use_cuda_graph = model_input.attn_metadata.use_cuda_graph

        # Get model
        if use_cuda_graph:
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = (self.graph_runners[model_input.virtual_engine]
                                [graph_batch_size])
        else:
            model_executable = self.model

        outputs: List[SamplerOutput] = []
        for step in range(num_steps):
            multi_modal_kwargs = model_input.multi_modal_kwargs or {}

            # Run the model
            hidden_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                kv_caches=kv_caches,
                attn_metadata=model_input.attn_metadata,
                intermediate_tensors=intermediate_tensors,
                **MultiModalInputs.as_kwargs(multi_modal_kwargs,
                                             device=self.device),
            )

            # Compute the logits.
            logits = self.model.compute_logits(hidden_states,
                                               model_input.sampling_metadata)

            # Sample the next token.
            outputs.append(
                self.model.sample(
                    logits=logits,
                    sampling_metadata=model_input.sampling_metadata,
                ))

            # Prepare inputs for the next step
            if step != num_steps - 1:
                model_input = self._gpu_advance_step(model_input, outputs[-1])

        return outputs