embedding_model_runner.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. import dataclasses
  2. from typing import Any, Dict, List, Optional, Tuple, Type
  3. import torch
  4. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  5. LoRAConfig, ModelConfig, ParallelConfig,
  6. PromptAdapterConfig, SchedulerConfig)
  7. from aphrodite.common.pooling_params import PoolingParams
  8. from aphrodite.common.sequence import (IntermediateTensors, PoolerOutput,
  9. SequenceData, SequenceGroupMetadata)
  10. from aphrodite.modeling.pooling_metadata import PoolingMetadata
  11. from aphrodite.multimodal import MultiModalInputs
  12. from aphrodite.task_handler.model_runner import (GPUModelRunnerBase,
  13. ModelInputForGPU,
  14. ModelInputForGPUBuilder)
@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
    """Model input for the GPU embedding runner.

    Extends :class:`ModelInputForGPU` with the pooling metadata needed by
    ``EmbeddingModelRunner`` to pool hidden states after the forward pass.
    Frozen so it can be safely broadcast between workers; the pooling
    metadata is attached via ``dataclasses.replace`` in
    ``prepare_model_input``.
    """
    # Populated by EmbeddingModelRunner.prepare_model_input; None until then.
    pooling_metadata: Optional["PoolingMetadata"] = None
  21. class EmbeddingModelRunner(
  22. GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
  23. _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
  24. ModelInputForGPUWithPoolingMetadata)
  25. _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
  26. def __init__(
  27. self,
  28. model_config: ModelConfig,
  29. parallel_config: ParallelConfig,
  30. scheduler_config: SchedulerConfig,
  31. device_config: DeviceConfig,
  32. cache_config: CacheConfig,
  33. load_config: LoadConfig,
  34. lora_config: Optional[LoRAConfig],
  35. kv_cache_dtype: Optional[str] = "auto",
  36. prompt_adapter_config: Optional[PromptAdapterConfig] = None,
  37. is_driver_worker: bool = False,
  38. tp_rank: int = 0,
  39. ):
  40. super().__init__(model_config,
  41. parallel_config,
  42. scheduler_config,
  43. device_config,
  44. cache_config,
  45. load_config,
  46. lora_config=lora_config,
  47. kv_cache_dtype=kv_cache_dtype,
  48. is_driver_worker=is_driver_worker,
  49. prompt_adapter_config=prompt_adapter_config,
  50. tp_rank=tp_rank)
  51. @torch.inference_mode()
  52. def execute_model(
  53. self,
  54. model_input: ModelInputForGPUWithPoolingMetadata,
  55. kv_caches: List[torch.Tensor],
  56. intermediate_tensors: Optional[IntermediateTensors] = None,
  57. num_steps: int = 1,
  58. ) -> Optional[List[PoolerOutput]]:
  59. if num_steps > 1:
  60. raise ValueError(
  61. "EmbeddingModelRunner does not support multi-step execution.")
  62. if self.lora_config:
  63. assert model_input.lora_requests is not None
  64. assert model_input.lora_mapping is not None
  65. self.set_active_loras(model_input.lora_requests,
  66. model_input.lora_mapping)
  67. if self.prompt_adapter_config:
  68. assert model_input.prompt_adapter_requests is not None
  69. assert model_input.prompt_adapter_mapping is not None
  70. self.set_active_prompt_adapters(
  71. model_input.prompt_adapter_requests,
  72. model_input.prompt_adapter_mapping)
  73. # Currently cuda graph is only supported by the decode phase.
  74. assert model_input.attn_metadata is not None
  75. prefill_meta = model_input.attn_metadata.prefill_metadata
  76. decode_meta = model_input.attn_metadata.decode_metadata
  77. virtual_engine = model_input.virtual_engine
  78. if prefill_meta is None and decode_meta.use_cuda_graph:
  79. assert model_input.input_tokens is not None
  80. graph_batch_size = model_input.input_tokens.shape[0]
  81. model_executable = self.graph_runners[virtual_engine][
  82. graph_batch_size]
  83. else:
  84. model_executable = self.model
  85. num_layers = self.model_config.get_num_layers(self.parallel_config)
  86. kv_caches = [None] * num_layers
  87. execute_model_kwargs = {
  88. "input_ids":
  89. model_input.input_tokens,
  90. "positions":
  91. model_input.input_positions,
  92. "kv_caches":
  93. kv_caches,
  94. "attn_metadata":
  95. model_input.attn_metadata,
  96. **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
  97. device=self.device),
  98. }
  99. hidden_states = model_executable(**execute_model_kwargs)
  100. # Only perform pooling in the driver worker.
  101. if not self.is_driver_worker:
  102. return []
  103. return [
  104. self.model.pooler(hidden_states=hidden_states,
  105. pooling_metadata=model_input.pooling_metadata)
  106. ]
  107. def make_model_input_from_broadcasted_tensor_dict(
  108. self,
  109. tensor_dict: Dict[str,
  110. Any]) -> ModelInputForGPUWithPoolingMetadata:
  111. return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
  112. tensor_dict,
  113. attn_backend=self.attn_backend,
  114. )
  115. def prepare_model_input(
  116. self,
  117. seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
  118. virtual_engine: int = 0,
  119. finished_requests_ids: Optional[List[str]] = None
  120. ) -> ModelInputForGPUWithPoolingMetadata:
  121. assert seq_group_metadata_list is not None
  122. model_input = self._prepare_model_input_tensors(
  123. seq_group_metadata_list, finished_requests_ids)
  124. # Prepare PoolingMetadata.
  125. assert model_input.seq_lens is not None
  126. pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
  127. model_input.seq_lens)
  128. return dataclasses.replace(model_input,
  129. pooling_metadata=pooling_metadata)
  130. def _prepare_pooling(
  131. self,
  132. seq_group_metadata_list: List[SequenceGroupMetadata],
  133. prompt_lens: List[int],
  134. ) -> PoolingMetadata:
  135. """Prepare PoolingMetadata for the sequence group metadata list."""
  136. seq_groups: List[Tuple[List[int], PoolingParams]] = []
  137. for i, seq_group_metadata in enumerate(seq_group_metadata_list):
  138. seq_ids = list(seq_group_metadata.seq_data.keys())
  139. pooling_params = seq_group_metadata.pooling_params
  140. seq_groups.append((seq_ids, pooling_params))
  141. seq_data: Dict[int, SequenceData] = {}
  142. for seq_group_metadata in seq_group_metadata_list:
  143. seq_data.update(seq_group_metadata.seq_data)
  144. pooling_metadata = PoolingMetadata(
  145. seq_groups=seq_groups,
  146. seq_data=seq_data,
  147. prompt_lens=prompt_lens,
  148. )
  149. return pooling_metadata