from typing import Dict, List, Optional, Set, Tuple

import torch

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig, ParallelConfig,
                                     SchedulerConfig, VisionLanguageConfig)
from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sequence import (PoolerOutput, SequenceData,
                                       SequenceGroupMetadata)
from aphrodite.distributed import broadcast_tensor_dict
from aphrodite.lora.layers import LoRAMapping
from aphrodite.lora.request import LoRARequest
from aphrodite.modeling.pooling_metadata import PoolingMetadata
from aphrodite.task_handler.model_runner import ModelRunner


class EmbeddingModelRunner(ModelRunner):
    """Model runner for embedding models.

    Runs the forward pass like ModelRunner, but applies the model's
    pooler to the hidden states instead of sampling tokens.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
    ):
        super().__init__(model_config,
                         parallel_config,
                         scheduler_config,
                         device_config,
                         cache_config,
                         load_config,
                         lora_config=lora_config,
                         kv_cache_dtype=kv_cache_dtype,
                         is_driver_worker=is_driver_worker,
                         vision_language_config=vision_language_config)
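
    # Construction sketch (hypothetical; the config objects normally come
    # from the engine's argument parsing rather than being built by hand,
    # and `load_model` is assumed to be inherited from ModelRunner):
    #
    #   runner = EmbeddingModelRunner(model_config, parallel_config,
    #                                 scheduler_config, device_config,
    #                                 cache_config, load_config,
    #                                 lora_config=None,
    #                                 is_driver_worker=True)
    #   runner.load_model()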

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[torch.Tensor],
    ) -> Optional[PoolerOutput]:
        (input_tokens, input_positions, attn_metadata, pooling_metadata,
         lora_requests, lora_mapping, multi_modal_input
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

        # Currently, CUDA graphs are only supported for the decode phase.
        prefill_meta = attn_metadata.prefill_metadata
        decode_meta = attn_metadata.decode_metadata
        if prefill_meta is None and decode_meta.use_cuda_graph:
            graph_batch_size = input_tokens.shape[0]
            model_executable = self.graph_runners[graph_batch_size]
        else:
            model_executable = self.model

        # Embedding models do not need a KV cache, so replace whatever was
        # passed in with per-layer placeholders.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers

        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
        }
        if self.vision_language_config:
            execute_model_kwargs.update({"image_input": multi_modal_input})
        hidden_states = model_executable(**execute_model_kwargs)

        # Pool the hidden states into the final embedding output.
        return self.model.pooler(hidden_states=hidden_states,
                                 pooling_metadata=pooling_metadata)
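
    # Call-flow sketch (hypothetical): on the driver worker, one scheduler
    # step reduces to a single call returning the pooled embeddings:
    #
    #   output: PoolerOutput = runner.execute_model(
    #       seq_group_metadata_list, kv_caches)
    #
    # Non-driver workers pass seq_group_metadata_list=None and receive
    # their inputs via the broadcast in prepare_input_tensors below.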

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
               Set[LoRARequest], LoRAMapping, torch.Tensor]:
        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            # Prepare input tensors.
            (
                input_tokens,
                input_positions,
                attn_metadata,
                seq_lens,
                _,
                lora_mapping,
                lora_requests,
                multi_modal_input,
                slot_mapping,
                num_prefill_tokens,
                num_decode_tokens,
                num_prefills,
            ) = self._prepare_model_input(seq_group_metadata_list)
            # Prepare PoolingMetadata.
            pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                     seq_lens)

            # Broadcast the input tensors and metadata to all other workers.
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
                "multi_modal_input": multi_modal_input,
                "num_prefill_tokens": num_prefill_tokens,
                "num_decode_tokens": num_decode_tokens,
                "slot_mapping": slot_mapping,
                "num_prefills": num_prefills,
            }
            if attn_metadata:
                metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            # Receive the inputs broadcast by the driver worker.
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
            multi_modal_input = metadata_dict.pop("multi_modal_input")
            if metadata_dict:
                attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)
            else:
                attn_metadata = None
            # The pooling metadata is not broadcast, so non-driver workers
            # use an empty placeholder.
            pooling_metadata = PoolingMetadata(seq_groups=None,
                                               seq_data=None,
                                               prompt_lens=None)

        return (input_tokens, input_positions, attn_metadata,
                pooling_metadata, lora_requests, lora_mapping,
                multi_modal_input)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list."""
        seq_groups: List[Tuple[List[int], PoolingParams]] = []
        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            pooling_params = seq_group_metadata.pooling_params
            seq_groups.append((seq_ids, pooling_params))

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        pooling_metadata = PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
        )

        return pooling_metadata
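
    # Shape sketch (hypothetical values): for two sequence groups with one
    # sequence each, _prepare_pooling would produce something like:
    #
    #   PoolingMetadata(
    #       seq_groups=[([0], PoolingParams()), ([1], PoolingParams())],
    #       seq_data={0: <SequenceData>, 1: <SequenceData>},
    #       prompt_lens=[12, 7],
    #   )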