embedding_model_runner.py

import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig, MultiModalConfig,
                                     ParallelConfig, PromptAdapterConfig,
                                     SchedulerConfig)
from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sequence import (IntermediateTensors, PoolerOutput,
                                       SequenceData, SequenceGroupMetadata)
from aphrodite.modeling.pooling_metadata import PoolingMetadata
from aphrodite.multimodal import MultiModalInputs
from aphrodite.task_handler.model_runner import (GPUModelRunnerBase,
                                                 ModelInputForGPU,
                                                 ModelInputForGPUBuilder)


@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
    """
    Used by the EmbeddingModelRunner.
    """
    pooling_metadata: Optional["PoolingMetadata"] = None
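

# EmbeddingModelRunner drives embedding models on the GPU. Unlike the
# generation path, execute_model() below runs the forward pass with
# placeholder KV caches and returns the PoolerOutput produced by the
# model's pooler instead of sampled tokens.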
class EmbeddingModelRunner(
        GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
    _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
        ModelInputForGPUWithPoolingMetadata)
    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        is_driver_worker: bool = False,
        multimodal_config: Optional[MultiModalConfig] = None,
        tp_rank: int = 0,
    ):
        super().__init__(model_config,
                         parallel_config,
                         scheduler_config,
                         device_config,
                         cache_config,
                         load_config,
                         lora_config=lora_config,
                         kv_cache_dtype=kv_cache_dtype,
                         is_driver_worker=is_driver_worker,
                         prompt_adapter_config=prompt_adapter_config,
                         multimodal_config=multimodal_config,
                         tp_rank=tp_rank)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithPoolingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[PoolerOutput]]:
        if num_steps > 1:
            raise ValueError(
                "EmbeddingModelRunner does not support multi-step execution.")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        virtual_engine = model_input.virtual_engine
        if prefill_meta is None and decode_meta.use_cuda_graph:
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[virtual_engine][
                graph_batch_size]
        else:
            model_executable = self.model
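
        # Embedding models do not use the KV cache during this forward pass,
        # so the caches handed in by the worker are replaced with per-layer
        # placeholders.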
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers

        execute_model_kwargs = {
            "input_ids": model_input.input_tokens,
            "positions": model_input.input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": model_input.attn_metadata,
            **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
                                         device=self.device),
        }

        hidden_states = model_executable(**execute_model_kwargs)

        # Only perform pooling in the driver worker.
        if not self.is_driver_worker:
            return []

        return [
            self.model.pooler(hidden_states=hidden_states,
                              pooling_metadata=model_input.pooling_metadata)
        ]

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForGPUWithPoolingMetadata:
        return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> ModelInputForGPUWithPoolingMetadata:
        assert seq_group_metadata_list is not None
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)

        # Prepare PoolingMetadata.
        assert model_input.seq_lens is not None
        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                 model_input.seq_lens)

        return dataclasses.replace(model_input,
                                   pooling_metadata=pooling_metadata)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list."""
        seq_groups: List[Tuple[List[int], PoolingParams]] = []
        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            pooling_params = seq_group_metadata.pooling_params
            seq_groups.append((seq_ids, pooling_params))

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        pooling_metadata = PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
        )

        return pooling_metadata