embedding_model_runner.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. import dataclasses
  2. from typing import Any, Dict, List, Optional, Tuple, Type
  3. import torch
  4. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  5. LoRAConfig, ModelConfig, ParallelConfig,
  6. SchedulerConfig, VisionLanguageConfig)
  7. from aphrodite.common.pooling_params import PoolingParams
  8. from aphrodite.common.sequence import (IntermediateTensors, PoolerOutput,
  9. SequenceData, SequenceGroupMetadata)
  10. from aphrodite.modeling.pooling_metadata import PoolingMetadata
  11. from aphrodite.task_handler.model_runner import (GPUModelRunnerBase,
  12. ModelInputForGPU)
@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
    """Model input for the GPU with pooling metadata attached.

    Used by the EmbeddingModelRunner.
    """
    # Pooling metadata for the batch. Attached by
    # EmbeddingModelRunner.prepare_model_input via dataclasses.replace;
    # defaults to None until then.
    pooling_metadata: Optional["PoolingMetadata"] = None
  19. class EmbeddingModelRunner(
  20. GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
  21. _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
  22. ModelInputForGPUWithPoolingMetadata)
  23. def __init__(
  24. self,
  25. model_config: ModelConfig,
  26. parallel_config: ParallelConfig,
  27. scheduler_config: SchedulerConfig,
  28. device_config: DeviceConfig,
  29. cache_config: CacheConfig,
  30. load_config: LoadConfig,
  31. lora_config: Optional[LoRAConfig],
  32. kv_cache_dtype: Optional[str] = "auto",
  33. is_driver_worker: bool = False,
  34. vision_language_config: Optional[VisionLanguageConfig] = None,
  35. ):
  36. super().__init__(model_config,
  37. parallel_config,
  38. scheduler_config,
  39. device_config,
  40. cache_config,
  41. load_config,
  42. lora_config=lora_config,
  43. kv_cache_dtype=kv_cache_dtype,
  44. is_driver_worker=is_driver_worker,
  45. vision_language_config=vision_language_config)
  46. @torch.inference_mode()
  47. def execute_model(
  48. self,
  49. model_input: ModelInputForGPUWithPoolingMetadata,
  50. kv_caches: List[torch.Tensor],
  51. intermediate_tensors: Optional[IntermediateTensors] = None,
  52. num_steps: int = 1,
  53. ) -> Optional[List[PoolerOutput]]:
  54. if num_steps > 1:
  55. raise ValueError(
  56. "EmbeddingModelRunner does not support multi-step execution.")
  57. if self.lora_config:
  58. assert model_input.lora_requests is not None
  59. assert model_input.lora_mapping is not None
  60. self.set_active_loras(model_input.lora_requests,
  61. model_input.lora_mapping)
  62. # Currently cuda graph is only supported by the decode phase.
  63. assert model_input.attn_metadata is not None
  64. prefill_meta = model_input.attn_metadata.prefill_metadata
  65. decode_meta = model_input.attn_metadata.decode_metadata
  66. virtual_engine = model_input.virtual_engine
  67. if prefill_meta is None and decode_meta.use_cuda_graph:
  68. assert model_input.input_tokens is not None
  69. graph_batch_size = model_input.input_tokens.shape[0]
  70. model_executable = self.graph_runners[virtual_engine][
  71. graph_batch_size]
  72. else:
  73. model_executable = self.model
  74. num_layers = self.model_config.get_num_layers(self.parallel_config)
  75. kv_caches = [None] * num_layers
  76. execute_model_kwargs = {
  77. "input_ids": model_input.input_tokens,
  78. "positions": model_input.input_positions,
  79. "kv_caches": kv_caches,
  80. "attn_metadata": model_input.attn_metadata,
  81. **(model_input.multi_modal_kwargs or {}),
  82. }
  83. hidden_states = model_executable(**execute_model_kwargs)
  84. # Only perform pooling in the driver worker.
  85. if not self.is_driver_worker:
  86. return []
  87. return [
  88. self.model.pooler(hidden_states=hidden_states,
  89. pooling_metadata=model_input.pooling_metadata)
  90. ]
  91. def make_model_input_from_broadcasted_tensor_dict(
  92. self,
  93. tensor_dict: Dict[str,
  94. Any]) -> ModelInputForGPUWithPoolingMetadata:
  95. return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
  96. tensor_dict,
  97. attn_backend=self.attn_backend,
  98. )
  99. def prepare_model_input(
  100. self,
  101. seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
  102. virtual_engine: int = 0,
  103. finished_requests_ids: Optional[List[str]] = None
  104. ) -> ModelInputForGPUWithPoolingMetadata:
  105. assert seq_group_metadata_list is not None
  106. model_input = self._prepare_model_input_tensors(
  107. seq_group_metadata_list, finished_requests_ids)
  108. # Prepare PoolingMetadata.
  109. assert model_input.seq_lens is not None
  110. pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
  111. model_input.seq_lens)
  112. return dataclasses.replace(model_input,
  113. pooling_metadata=pooling_metadata)
  114. def _prepare_pooling(
  115. self,
  116. seq_group_metadata_list: List[SequenceGroupMetadata],
  117. prompt_lens: List[int],
  118. ) -> PoolingMetadata:
  119. """Prepare PoolingMetadata for the sequence group metadata list."""
  120. seq_groups: List[Tuple[List[int], PoolingParams]] = []
  121. for i, seq_group_metadata in enumerate(seq_group_metadata_list):
  122. seq_ids = list(seq_group_metadata.seq_data.keys())
  123. pooling_params = seq_group_metadata.pooling_params
  124. seq_groups.append((seq_ids, pooling_params))
  125. seq_data: Dict[int, SequenceData] = {}
  126. for seq_group_metadata in seq_group_metadata_list:
  127. seq_data.update(seq_group_metadata.seq_data)
  128. pooling_metadata = PoolingMetadata(
  129. seq_groups=seq_groups,
  130. seq_data=seq_data,
  131. prompt_lens=prompt_lens,
  132. )
  133. return pooling_metadata