from typing import List, Optional

import torch

from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig, MultiModalConfig,
                                     ParallelConfig, PromptAdapterConfig,
                                     SchedulerConfig)
from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                       SequenceGroupMetadata)
from aphrodite.task_handler.model_runner import (
    ModelInputForGPUWithSamplingMetadata, ModelRunner)


class TP1DraftModelRunner(ModelRunner):
    """Specialized model runner for the speculative decoding draft model.

    Since the draft model always executes k consecutive forward passes to
    generate k speculative tokens in a single speculative decoding step, we
    could get rid of most CPU-GPU synchronization and data-transfer overheads
    by keeping the model input and output tensors on the GPU the whole time.

    This runner is still under development, so there is no performance gain
    at the moment. Currently we adopt a temporary solution that caches the
    seq_group_metadata_list for multi-step execution, so that we can reuse
    the existing prepare_model_input and stay compatible with the current
    execution flow, but we plan to remove this cache and avoid calling
    prepare_model_input from execute_model at all.

    The detailed development plan includes:
    1. Use "update_model_input" to update the existing model_input without
       creating a new one.
    2. Improve the performance of "update_model_input" with a GPU kernel.
    3. Support TP > 1 (this requires some design work because we do not
       expect any broadcasting inside execute_model).
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        multimodal_config: Optional[MultiModalConfig] = None,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        return_hidden_states: bool = False,
    ):
        if return_hidden_states:
            raise ValueError(
                "return_hidden_states is not supported for TP1DraftModelRunner."
            )

        super().__init__(
            model_config=model_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            cache_config=cache_config,
            load_config=load_config,
            lora_config=lora_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
            multimodal_config=multimodal_config,
            prompt_adapter_config=prompt_adapter_config,
            return_hidden_states=return_hidden_states,
        )

        # Temporary cache of the latest seq_group_metadata_list, used by
        # update_model_input to rebuild the model input for the next step.
        self.cached_seq_group_metadata_list: Optional[
            List[SequenceGroupMetadata]] = None

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForGPUWithSamplingMetadata:
        """A temporary solution that caches the seq_group_metadata_list
        for multi-step execution.
        TODO: In-place update model_input and remove this function.
        """
        self.cached_seq_group_metadata_list = seq_group_metadata_list
        return super().prepare_model_input(
            seq_group_metadata_list,
            finished_requests_ids=finished_requests_ids)
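
    # Typical call pattern, for orientation (a sketch; in practice the
    # speculative-decoding worker is assumed to drive the runner):
    #
    #   model_input = runner.prepare_model_input(seq_group_metadata_list)
    #   outputs = runner.execute_model(model_input, kv_caches, num_steps=k)
    #
    # Each intermediate step inside execute_model re-enters
    # prepare_model_input via update_model_input until this caching
    # workaround is removed.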

    def update_model_input(
            self, model_input: ModelInputForGPUWithSamplingMetadata,
            last_output: SamplerOutput
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Prepare the model inputs for the next step.
        TODO: In-place update model_input instead of calling
        prepare_model_input.
        """

        # Append the sampled output token of each sequence to its cached
        # sequence data so the next prepare_model_input picks it up.
        assert self.cached_seq_group_metadata_list is not None
        for seq_group_metadata, sequence_group_outputs in zip(
                self.cached_seq_group_metadata_list, last_output.outputs):
            seq_group_metadata.is_prompt = False

            for seq_output in sequence_group_outputs.samples:
                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]

                token_id = seq_output.output_token
                token_logprob = seq_output.logprobs[token_id]

                seq.append_token_id(token_id, token_logprob.logprob)
                seq.update_num_computed_tokens(1)
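
        # Worked example of the bookkeeping above (values illustrative only):
        # if a sequence's token ids currently end in [..., 5, 9] and the last
        # step sampled token_id == 7, the loop appends 7 together with its
        # logprob and advances num_computed_tokens by one, so the rebuilt
        # model input performs a single-token decode step that feeds 7 in.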
        return self.prepare_model_input(self.cached_seq_group_metadata_list)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
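        # execute_model does no broadcasting, so supporting TP > 1 would at
        # least require broadcasting the sampled tokens to all workers (see
        # item 3 of the class docstring); only the TP=1 driver worker may
        # call this.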
        if not self.is_driver_worker:
            raise ValueError("TP1DraftModelRunner only supports TP=1.")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        virtual_engine = model_input.virtual_engine
        outputs: List[SamplerOutput] = []
        for step in range(num_steps):
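            # CUDA graphs are only captured for pure decode batches, so use
            # the pre-captured graph runner for this batch size when there is
            # no prefill work; otherwise run the model eagerly.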
            assert model_input.attn_metadata is not None
            prefill_meta = model_input.attn_metadata.prefill_metadata
            decode_meta = model_input.attn_metadata.decode_metadata
            if prefill_meta is None and decode_meta.use_cuda_graph:
                assert model_input.input_tokens is not None
                graph_batch_size = model_input.input_tokens.shape[0]
                model_executable = (
                    self.graph_runners[virtual_engine][graph_batch_size])
            else:
                model_executable = self.model

            multi_modal_kwargs = model_input.multi_modal_kwargs or {}
            hidden_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                kv_caches=kv_caches,
                attn_metadata=model_input.attn_metadata,
                intermediate_tensors=intermediate_tensors,
                **multi_modal_kwargs,
            )

            # Compute the logits for this step's hidden states.
            logits = self.model.compute_logits(hidden_states,
                                               model_input.sampling_metadata)

            # Sample the next token for every sequence group.
            outputs.append(
                self.model.sample(
                    logits=logits,
                    sampling_metadata=model_input.sampling_metadata,
                ))

            # Rebuild the model input for the next step (skipped after the
            # last step).
            if step != num_steps - 1:
                model_input = self.update_model_input(model_input, outputs[-1])

        return outputs
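

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the runner): a minimal, self-contained toy
# showing the shape of the multi-step loop in execute_model above: run k
# consecutive forward passes of a small "draft model", feeding the token
# sampled at each step straight back in as the next input, with no CPU-side
# scheduling in between. The toy model, shapes, and greedy sampling here are
# invented for the example; only the loop structure mirrors the runner.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    vocab_size, hidden_size, k = 32, 16, 4

    # Stand-in draft model: embedding + linear head.
    embed = torch.nn.Embedding(vocab_size, hidden_size)
    head = torch.nn.Linear(hidden_size, vocab_size)

    input_ids = torch.tensor([[1, 2, 3]])  # (batch=1, prompt_len=3)
    draft_tokens: List[int] = []

    with torch.inference_mode():
        for step in range(k):
            hidden = embed(input_ids).mean(dim=1)   # "forward pass"
            logits = head(hidden)                   # compute logits
            next_token = logits.argmax(dim=-1)      # greedy "sampling"
            draft_tokens.append(int(next_token))
            # Analogue of update_model_input: append the sampled token so the
            # next step consumes it.
            input_ids = torch.cat([input_ids, next_token[None]], dim=1)

    print(f"Drafted {k} tokens: {draft_tokens}")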