draft_model_runner.py

from typing import List, Optional

import torch
from loguru import logger

try:
    from aphrodite.attention.backends.flash_attn import FlashAttentionMetadata
except ModuleNotFoundError:
    # aphrodite_flash_attn is not installed, use the identical ROCm FA metadata
    from aphrodite.attention.backends.rocm_flash_attn import (
        ROCmFlashAttentionMetadata as FlashAttentionMetadata)

from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig, ParallelConfig,
                                     PromptAdapterConfig, SchedulerConfig)
from aphrodite.common.sequence import ExecuteModelRequest, IntermediateTensors
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.multimodal import MultiModalInputs
from aphrodite.worker.model_runner import (
    ModelInputForGPUWithSamplingMetadata, ModelRunner)

# A flag to enable debug prints for the updated input tensors
# before each step.
debug_advance_input = False
# A flag to allow GPU advance step for draft model runner.
# Set to False for debugging.
allow_gpu_advance_step = True
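# Both flags are debugging toggles: debug_advance_input gates the tensor dump
# in _gpu_advance_step(), while setting allow_gpu_advance_step to False makes
# supports_gpu_multi_step() return False so callers take the single-step
# (CPU-prepared) fallback path instead.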


class TP1DraftModelRunner(ModelRunner):
    """Specialized model runner for the speculative decoding draft model.

    Since the draft model always executes k forward passes consecutively to
    generate k speculative tokens in a single speculative decoding step,
    we can get rid of most CPU-GPU synchronization and data transfer
    overheads by keeping model input and output tensors on the GPU all the
    time.

    TODOs:
    1. Currently supports only flash-attn, add support for other attn_backends.
    2. Support TP > 1 (this requires some design work because we do not expect
       any broadcasting inside execute_model).
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        return_hidden_states: bool = False,
        **kwargs,  # for uneven TP
    ):
        if return_hidden_states:
            raise ValueError(
                "return_hidden_states is not supported for TP1DraftModelRunner."
            )

        super().__init__(
            model_config=model_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            cache_config=cache_config,
            load_config=load_config,
            lora_config=lora_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
            prompt_adapter_config=prompt_adapter_config,
            return_hidden_states=return_hidden_states,
            **kwargs,
        )

    def _update_sampling_metadata(self, sampling_metadata, num_seqs,
                                  num_queries):

        assert sampling_metadata.num_prompts == 0
        assert len(sampling_metadata.seq_groups) == num_queries
        assert sampling_metadata.selected_token_indices.shape == (
            num_queries, )
        # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501

        # Verify that all sequences are decodes
        for i in range(num_queries):
            seq_group = sampling_metadata.seq_groups[i]

            assert seq_group.is_prompt is False  # No prompt
            assert seq_group.prompt_logprob_indices == []  # No prompt
            assert seq_group.sample_indices == [i]  # Simple
            assert seq_group.seq_len is None  # Decode
            assert seq_group.query_len is None  # Decode

    def _gpu_advance_step(
            self, model_input: ModelInputForGPUWithSamplingMetadata,
            last_output: SamplerOutput
    ) -> ModelInputForGPUWithSamplingMetadata:
        # Currently, we expect "decode mode" only
        assert not model_input.is_prompt

        # Get num_seqs
        num_seqs = len(model_input.seq_lens)
        num_queries = len(model_input.query_lens)

        # Get output tokens GPU tensor
        sampled_token_ids = last_output.sampled_token_ids
        assert sampled_token_ids is not None

        # Update attn_metadata
        attn_metadata = model_input.attn_metadata
        assert isinstance(attn_metadata, FlashAttentionMetadata)
        attn_metadata.advance_step(model_input, sampled_token_ids,
                                   self.block_size, num_seqs, num_queries)

        # Update sampling_metadata
        sampling_metadata = model_input.sampling_metadata
        self._update_sampling_metadata(sampling_metadata, num_seqs,
                                       num_queries)

        # Create new input
        new_model_input = self._model_input_cls(
            input_tokens=model_input.input_tokens,
            input_positions=model_input.input_positions,
            attn_metadata=attn_metadata,
            seq_lens=attn_metadata.seq_lens,
            query_lens=model_input.query_lens,
            lora_mapping=model_input.lora_mapping,
            lora_requests=model_input.lora_requests,
            multi_modal_kwargs=model_input.multi_modal_kwargs,
            sampling_metadata=model_input.sampling_metadata,
            is_prompt=False,
        )

        # Ensure we skip CPU samples
        assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True

        # We can reuse sampling tensors since every decode iteration is the same
        new_model_input.sampling_metadata.reuse_sampling_tensors = True

        if debug_advance_input:
            logger.debug("NEW INPUT: ")
            logger.debug(f"  input_tokens = {new_model_input.input_tokens}")
            logger.debug("  input_positions = "
                         f"{new_model_input.input_positions}")
            logger.debug(f"  seq_lens = {new_model_input.seq_lens}")
            logger.debug(f"  query_lens = {new_model_input.query_lens}")
            logger.debug("  attn_metadata:")
            logger.debug("    seq_lens_tensor: "
                         f"{attn_metadata.seq_lens_tensor}")
            logger.debug(f"    slot_mapping: {attn_metadata.slot_mapping}")
            logger.debug(f"    block_tables: {attn_metadata.block_tables}")

        return new_model_input

    def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
        """Determines if draft_model_runner GPU multi-step can be used.

        Currently required conditions are:
            1. Only decodes
            2. Only flash-attn
            3. No LORA
            4. No prompt_adapter_config
        """
        if not allow_gpu_advance_step:
            return False

        # We allow multi-step GPU only in decode mode
        for seq_group in execute_model_req.seq_group_metadata_list:
            if seq_group.is_prompt:
                return False

        # TODO: Add support for other attn backends
        if self.attn_backend.get_name() != "flash-attn":
            return False

        # TODO: Add support for LORA
        if self.lora_config:
            return False

        # TODO: Add soft-tuning prompt adapter support
        if self.prompt_adapter_config:
            return False

        return True

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        previous_hidden_states: Optional[torch.Tensor] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
  167. """Executes num_steps forward passes with advacement of input tensors
  168. on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
  169. Optimizations used:
  170. 1. Input tensors are updated on the GPU directly
  171. 2. Skips GPU=>CPU serialization of sampler outputs (we don't need
  172. them since we do batch expansion later that uses GPU outputs)
  173. 3. Reuses sampling tensors (since we run only decodes and they have
  174. a repeating sampling logic)
  175. """

        # When num_steps == 1, this is the fallback path for the GPU
        # advance_step: the caller runs prepare_inputs on the CPU and invokes
        # this function once per spec iteration.
        # (Look at the multi-step worker code.)
        is_fallback = num_steps == 1
        if not is_fallback:
            # Since we do not broadcast data inside execute_model anymore,
            # we need to figure out the best way to support TP > 1 in this
            # case, because we will at least need to broadcast the sampled
            # tokens to all workers.
            if not self.is_driver_worker:
                raise ValueError("TP1DraftModelRunner only supports TP=1.")

            # Sanity
            if self.lora_config is not None:
                raise ValueError("TP1DraftModelRunner has no support for LORA")
            if self.prompt_adapter_config is not None:
                raise ValueError("TP1DraftModelRunner has no support for "
                                 "prompt_adapter_config")
            if model_input.multi_modal_kwargs:
                raise ValueError(
                    "TP1DraftModelRunner has no support for multi_modal_kwargs"
                )
        else:
            if self.lora_config:
                assert model_input.lora_requests is not None
                assert model_input.lora_mapping is not None
                self.set_active_loras(model_input.lora_requests,
                                      model_input.lora_mapping)

            if self.prompt_adapter_config:
                assert model_input.prompt_adapter_requests is not None
                assert model_input.prompt_adapter_mapping is not None
                self.set_active_prompt_adapters(
                    model_input.prompt_adapter_requests,
                    model_input.prompt_adapter_mapping)

        self.attn_state.begin_forward(model_input)

        # Detect exec mode
        assert model_input.attn_metadata is not None
        use_cuda_graph = False
        if model_input.attn_metadata.num_prefills > 0:
            # In this case, execute_model(..) was called directly
            if num_steps > 1:
                raise ValueError(
                    "execute_model(..) of draft_model_runner can be called "
                    "directly only with a single-step prefill")
        else:
            # We can skip CPU samples for spec token generation.
            # (We do allow CPU samples for num_steps == 1 to support the
            # fallback case, where supports_gpu_multi_step(..) does not pass)
            model_input.sampling_metadata.skip_sampler_cpu_output = (
                not is_fallback)

            # Attn attr defines if we use cuda graphs
            use_cuda_graph = model_input.attn_metadata.use_cuda_graph

        # Get model
        if use_cuda_graph:
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = (self.graph_runners[model_input.virtual_engine]
                                [graph_batch_size])
            if previous_hidden_states is not None:
                # Pad previous_hidden_states up to the captured graph batch
                # size with an uninitialized tensor of matching dtype/device.
                hidden_states = torch.cat([
                    previous_hidden_states,
                    torch.empty([
                        graph_batch_size - previous_hidden_states.shape[0],
                        *previous_hidden_states.shape[1:]
                    ],
                                dtype=previous_hidden_states.dtype,
                                device=previous_hidden_states.device)
                ])
            else:
                hidden_states = None
        else:
            model_executable = self.model
            hidden_states = previous_hidden_states

        outputs: List[SamplerOutput] = []
        for step in range(num_steps):
            multi_modal_kwargs = model_input.multi_modal_kwargs or {}

            kwargs = {"previous_hidden_states": hidden_states} \
                if previous_hidden_states is not None else {}

            # Run model
            hidden_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                kv_caches=kv_caches,
                attn_metadata=model_input.attn_metadata,
                intermediate_tensors=intermediate_tensors,
                **MultiModalInputs.as_kwargs(multi_modal_kwargs,
                                             device=self.device),
                **kwargs,
            )

            # Compute the logits.
            logits = self.model.compute_logits(hidden_states,
                                               model_input.sampling_metadata)

            # Sample the next token.
            outputs.append(
                self.model.sample(
                    logits=logits,
                    sampling_metadata=model_input.sampling_metadata,
                ))

            # Prepare inputs for the next step
            if step != num_steps - 1:
                model_input = self._gpu_advance_step(model_input, outputs[-1])

        return outputs