123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174 |
- import dataclasses
- from typing import Any, Dict, List, Optional, Tuple, Type
- import torch
- from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
- LoRAConfig, ModelConfig, MultiModalConfig,
- ParallelConfig, PromptAdapterConfig,
- SchedulerConfig)
- from aphrodite.common.pooling_params import PoolingParams
- from aphrodite.common.sequence import (IntermediateTensors, PoolerOutput,
- SequenceData, SequenceGroupMetadata)
- from aphrodite.modeling.pooling_metadata import PoolingMetadata
- from aphrodite.multimodal import MultiModalInputs
- from aphrodite.task_handler.model_runner import (GPUModelRunnerBase,
- ModelInputForGPU,
- ModelInputForGPUBuilder)
@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
    """Model input consumed by the EmbeddingModelRunner.

    Extends ``ModelInputForGPU`` with a single optional field that carries
    the pooling metadata for the batch.
    """

    # Left as None by default; EmbeddingModelRunner.prepare_model_input
    # fills it in via dataclasses.replace.
    pooling_metadata: Optional["PoolingMetadata"] = None
class EmbeddingModelRunner(
        GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
    """GPU model runner that pools hidden states.

    ``execute_model`` runs one forward pass and passes the resulting hidden
    states to ``self.model.pooler`` together with the pooling metadata
    built in ``prepare_model_input``.
    """

    _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
        ModelInputForGPUWithPoolingMetadata)
    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        is_driver_worker: bool = False,
        multimodal_config: Optional[MultiModalConfig] = None,
        tp_rank: int = 0,
    ):
        """Forward every argument, unchanged, to ``GPUModelRunnerBase``."""
        super().__init__(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config,
            lora_config=lora_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
            prompt_adapter_config=prompt_adapter_config,
            multimodal_config=multimodal_config,
            tp_rank=tp_rank,
        )

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithPoolingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[PoolerOutput]]:
        """Run one forward pass and pool the resulting hidden states.

        Args:
            model_input: Prepared batch, including attention and pooling
                metadata.
            kv_caches: Not used as passed in; it is replaced below by a
                list of ``None`` placeholders, one per layer.
            intermediate_tensors: Not read by this method.
            num_steps: Must be 1.

        Returns:
            An empty list on non-driver workers, otherwise a one-element
            list holding the pooler output.

        Raises:
            ValueError: If ``num_steps`` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError(
                "EmbeddingModelRunner does not support multi-step execution.")

        # Activate the adapters requested for this batch, if configured.
        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        attn_metadata = model_input.attn_metadata
        assert attn_metadata is not None
        prefill_meta = attn_metadata.prefill_metadata
        decode_meta = attn_metadata.decode_metadata
        virtual_engine = model_input.virtual_engine

        # Currently cuda graph is only supported by the decode phase, so
        # any batch with prefill metadata runs the plain model.
        if prefill_meta is not None or not decode_meta.use_cuda_graph:
            model_executable = self.model
        else:
            assert model_input.input_tokens is not None
            # Captured graphs are looked up by the batch dimension of the
            # input tokens.
            graph_batch_size = model_input.input_tokens.shape[0]
            runners_for_engine = self.graph_runners[virtual_engine]
            model_executable = runners_for_engine[graph_batch_size]

        # The caller-supplied kv_caches are discarded in favour of one None
        # placeholder per layer.
        # NOTE(review): presumably embedding models do not use the KV
        # cache -- confirm.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers

        execute_model_kwargs = {
            "input_ids": model_input.input_tokens,
            "positions": model_input.input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
        }
        # Multi-modal entries are merged last so they take precedence on a
        # key clash, exactly like ``**`` unpacking at the end of a literal.
        execute_model_kwargs.update(
            MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
                                       device=self.device))

        hidden_states = model_executable(**execute_model_kwargs)

        # Only perform pooling in the driver worker.
        if not self.is_driver_worker:
            return []

        pooled_output = self.model.pooler(
            hidden_states=hidden_states,
            pooling_metadata=model_input.pooling_metadata)
        return [pooled_output]

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForGPUWithPoolingMetadata:
        """Rebuild the model input from a broadcast tensor dict.

        Delegates to ``from_broadcasted_tensor_dict`` on the model input
        class, passing this runner's attention backend.
        """
        input_cls = ModelInputForGPUWithPoolingMetadata
        return input_cls.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForGPUWithPoolingMetadata:
        """Build the model input tensors and attach pooling metadata.

        ``seq_group_metadata_list`` must not be None. ``virtual_engine`` is
        accepted but not read by this method.
        """
        assert seq_group_metadata_list is not None
        tensors_only_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)

        # Prepare PoolingMetadata.
        seq_lens = tensors_only_input.seq_lens
        assert seq_lens is not None
        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                 seq_lens)

        # The input dataclass is frozen, so return a copy with the extra
        # field set.
        return dataclasses.replace(tensors_only_input,
                                   pooling_metadata=pooling_metadata)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list.

        Collects, per sequence group, its sequence ids paired with its
        pooling params, and merges every group's ``seq_data`` into a single
        dict (later groups overwrite earlier ones on duplicate ids).
        """
        seq_groups: List[Tuple[List[int], PoolingParams]] = [
            (list(group_metadata.seq_data.keys()),
             group_metadata.pooling_params)
            for group_metadata in seq_group_metadata_list
        ]

        merged_seq_data: Dict[int, SequenceData] = {}
        for group_metadata in seq_group_metadata_list:
            merged_seq_data.update(group_metadata.seq_data)

        return PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=merged_seq_data,
            prompt_lens=prompt_lens,
        )
|