from typing import List, Optional, Union

import torch

from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.distributed import get_pp_group
from aphrodite.multimodal import MultiModalInputs
from aphrodite.task_handler.model_runner import (
    FLASHINFER_WORKSPACE_BUFFER_SIZE, BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper, ModelInputForGPUWithSamplingMetadata,
    ModelRunner)


class CFGModelRunner(ModelRunner):
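    """ModelRunner that splits ``execute_model`` into separate stages
    (``model_execute``, ``get_logits``, ``compute_logits``, ``do_sample``)
    so a wrapping worker, e.g. one implementing classifier-free guidance,
    can intervene between the forward pass and sampling."""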

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @torch.inference_mode()
    def model_execute(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> torch.Tensor:
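        """Run only the model forward pass and return the hidden states
        (or intermediate tensors on non-final pipeline ranks), without
        computing logits or sampling."""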
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in ModelRunner")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

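        # With the flashinfer attention backend, lazily allocate the decode
        # and prefill workspace buffers and wrappers on first use.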
        if self.attn_backend.get_name() == "flashinfer":
            assert model_input.attn_metadata is not None
            assert model_input.input_tokens is not None
            if self.flashinfer_decode_workspace_buffer is None:
                self.flashinfer_decode_workspace_buffer = torch.empty(
                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
                    dtype=torch.uint8,
                    device=self.device)
                self.flashinfer_decode_wrapper = \
                    BatchDecodeWithPagedKVCacheWrapper(
                        self.flashinfer_decode_workspace_buffer, "NHD")
                self.flashinfer_prefill_workspace_buffer = torch.empty(
                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
                    dtype=torch.uint8,
                    device=self.device)
                self.flashinfer_prefill_wrapper = \
                    BatchPrefillWithPagedKVCacheWrapper(
                        self.flashinfer_prefill_workspace_buffer, "NHD")

            model_input.attn_metadata.prefill_wrapper = \
                self.flashinfer_prefill_wrapper
            if model_input.attn_metadata.use_cuda_graph:
                batch_size = model_input.input_tokens.shape[0]
                model_input.attn_metadata.decode_wrapper = self.graph_runners[
                    model_input.
                    virtual_engine][batch_size].flashinfer_decode_wrapper
            else:
                model_input.attn_metadata.decode_wrapper = \
                    self.flashinfer_decode_wrapper
            model_input.attn_metadata.begin_forward()

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        # TODO(andoorve): We can remove this once all
        # virtual engines share the same kv cache.
        virtual_engine = model_input.virtual_engine
        if prefill_meta is None and decode_meta.use_cuda_graph:
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[virtual_engine][
                graph_batch_size]
        else:
            model_executable = self.model

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_seqlen_agnostic else {}
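
        # Execute the forward pass; the result is either hidden states (last
        # pipeline-parallel rank) or IntermediateTensors to pass on to the
        # next rank.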
        hidden_or_intermediate_states = model_executable(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            kv_caches=kv_caches,
            attn_metadata=model_input.attn_metadata,
            intermediate_tensors=intermediate_tensors,
            **MultiModalInputs.as_kwargs(multi_modal_kwargs,
                                         device=self.device),
            **seqlen_agnostic_kwargs)

        return hidden_or_intermediate_states

    @torch.inference_mode()
    def get_logits(
        self,
        hidden_or_intermediate_states: torch.Tensor,
        model_input: ModelInputForGPUWithSamplingMetadata,
    ) -> torch.Tensor:
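        """Turn the hidden states into vocabulary logits via the model's
        ``_get_logits``."""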
        return self.model._get_logits(hidden_or_intermediate_states,
                                      model_input.sampling_metadata)

    @torch.inference_mode()
    def compute_logits(
        self,
        logits: torch.Tensor,
        model_input: ModelInputForGPUWithSamplingMetadata,
    ) -> torch.Tensor:
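        """Apply the model's ``compute_logits`` step to the given logits
        with the current sampling metadata."""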
        return self.model.compute_logits(logits,
                                         model_input.sampling_metadata)

    @torch.inference_mode()
    def do_sample(
        self,
        logits: torch.Tensor,
        model_input: ModelInputForGPUWithSamplingMetadata,
    ):
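        """Sample the next tokens from the final logits. Only the driver
        worker samples; other workers return an empty list."""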
        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )

        if self.return_hidden_states:
            raise NotImplementedError("return_hidden_states is not supported "
                                      "in CFGModelRunner")

        return [output]

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
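        """Default path chaining the stages above: forward pass, raw logits,
        logits post-processing, then sampling. Non-final pipeline ranks
        return the intermediate tensors directly."""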
        hidden_or_intermediate_states = self.model_execute(
            model_input, kv_caches, intermediate_tensors, num_steps)

        if not get_pp_group().is_last_rank:
            return hidden_or_intermediate_states

        hidden_or_intermediate_states = self.get_logits(
            hidden_or_intermediate_states, model_input)
        logits = self.compute_logits(hidden_or_intermediate_states,
                                     model_input)

        return self.do_sample(logits, model_input)