# embedding_model_runner.py
  1. import dataclasses
  2. from typing import Any, Dict, List, Optional, Tuple, Type
  3. import torch
  4. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  5. LoRAConfig, ModelConfig, MultiModalConfig,
  6. ParallelConfig, PromptAdapterConfig,
  7. SchedulerConfig)
  8. from aphrodite.common.pooling_params import PoolingParams
  9. from aphrodite.common.sequence import (IntermediateTensors, PoolerOutput,
  10. SequenceData, SequenceGroupMetadata)
  11. from aphrodite.modeling.pooling_metadata import PoolingMetadata
  12. from aphrodite.task_handler.model_runner import (GPUModelRunnerBase,
  13. ModelInputForGPU)
@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
    """
    Model input for embedding runs; used by the EmbeddingModelRunner.

    Extends ModelInputForGPU with the pooling metadata needed to turn
    hidden states into pooled embeddings.
    """
    # Filled in by EmbeddingModelRunner.prepare_model_input via
    # dataclasses.replace; None until then (e.g. on non-driver workers
    # before the broadcast tensor dict is applied).
    pooling_metadata: Optional["PoolingMetadata"] = None
  20. class EmbeddingModelRunner(
  21. GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
  22. _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
  23. ModelInputForGPUWithPoolingMetadata)
  24. def __init__(
  25. self,
  26. model_config: ModelConfig,
  27. parallel_config: ParallelConfig,
  28. scheduler_config: SchedulerConfig,
  29. device_config: DeviceConfig,
  30. cache_config: CacheConfig,
  31. load_config: LoadConfig,
  32. lora_config: Optional[LoRAConfig],
  33. kv_cache_dtype: Optional[str] = "auto",
  34. prompt_adapter_config: Optional[PromptAdapterConfig] = None,
  35. is_driver_worker: bool = False,
  36. multimodal_config: Optional[MultiModalConfig] = None,
  37. ):
  38. super().__init__(model_config,
  39. parallel_config,
  40. scheduler_config,
  41. device_config,
  42. cache_config,
  43. load_config,
  44. lora_config=lora_config,
  45. kv_cache_dtype=kv_cache_dtype,
  46. is_driver_worker=is_driver_worker,
  47. prompt_adapter_config=prompt_adapter_config,
  48. multimodal_config=multimodal_config)
  49. @torch.inference_mode()
  50. def execute_model(
  51. self,
  52. model_input: ModelInputForGPUWithPoolingMetadata,
  53. kv_caches: List[torch.Tensor],
  54. intermediate_tensors: Optional[IntermediateTensors] = None,
  55. num_steps: int = 1,
  56. ) -> Optional[List[PoolerOutput]]:
  57. if num_steps > 1:
  58. raise ValueError(
  59. "EmbeddingModelRunner does not support multi-step execution.")
  60. if self.lora_config:
  61. assert model_input.lora_requests is not None
  62. assert model_input.lora_mapping is not None
  63. self.set_active_loras(model_input.lora_requests,
  64. model_input.lora_mapping)
  65. if self.prompt_adapter_config:
  66. assert model_input.prompt_adapter_requests is not None
  67. assert model_input.prompt_adapter_mapping is not None
  68. self.set_active_prompt_adapters(
  69. model_input.prompt_adapter_requests,
  70. model_input.prompt_adapter_mapping)
  71. # Currently cuda graph is only supported by the decode phase.
  72. assert model_input.attn_metadata is not None
  73. prefill_meta = model_input.attn_metadata.prefill_metadata
  74. decode_meta = model_input.attn_metadata.decode_metadata
  75. virtual_engine = model_input.virtual_engine
  76. if prefill_meta is None and decode_meta.use_cuda_graph:
  77. assert model_input.input_tokens is not None
  78. graph_batch_size = model_input.input_tokens.shape[0]
  79. model_executable = self.graph_runners[virtual_engine][
  80. graph_batch_size]
  81. else:
  82. model_executable = self.model
  83. num_layers = self.model_config.get_num_layers(self.parallel_config)
  84. kv_caches = [None] * num_layers
  85. execute_model_kwargs = {
  86. "input_ids": model_input.input_tokens,
  87. "positions": model_input.input_positions,
  88. "kv_caches": kv_caches,
  89. "attn_metadata": model_input.attn_metadata,
  90. **(model_input.multi_modal_kwargs or {}),
  91. }
  92. hidden_states = model_executable(**execute_model_kwargs)
  93. # Only perform pooling in the driver worker.
  94. if not self.is_driver_worker:
  95. return []
  96. return [
  97. self.model.pooler(hidden_states=hidden_states,
  98. pooling_metadata=model_input.pooling_metadata)
  99. ]
  100. def make_model_input_from_broadcasted_tensor_dict(
  101. self,
  102. tensor_dict: Dict[str,
  103. Any]) -> ModelInputForGPUWithPoolingMetadata:
  104. return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
  105. tensor_dict,
  106. attn_backend=self.attn_backend,
  107. )
  108. def prepare_model_input(
  109. self,
  110. seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
  111. virtual_engine: int = 0,
  112. finished_requests_ids: Optional[List[str]] = None
  113. ) -> ModelInputForGPUWithPoolingMetadata:
  114. assert seq_group_metadata_list is not None
  115. model_input = self._prepare_model_input_tensors(
  116. seq_group_metadata_list, finished_requests_ids)
  117. # Prepare PoolingMetadata.
  118. assert model_input.seq_lens is not None
  119. pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
  120. model_input.seq_lens)
  121. return dataclasses.replace(model_input,
  122. pooling_metadata=pooling_metadata)
  123. def _prepare_pooling(
  124. self,
  125. seq_group_metadata_list: List[SequenceGroupMetadata],
  126. prompt_lens: List[int],
  127. ) -> PoolingMetadata:
  128. """Prepare PoolingMetadata for the sequence group metadata list."""
  129. seq_groups: List[Tuple[List[int], PoolingParams]] = []
  130. for i, seq_group_metadata in enumerate(seq_group_metadata_list):
  131. seq_ids = list(seq_group_metadata.seq_data.keys())
  132. pooling_params = seq_group_metadata.pooling_params
  133. seq_groups.append((seq_ids, pooling_params))
  134. seq_data: Dict[int, SequenceData] = {}
  135. for seq_group_metadata in seq_group_metadata_list:
  136. seq_data.update(seq_group_metadata.seq_data)
  137. pooling_metadata = PoolingMetadata(
  138. seq_groups=seq_groups,
  139. seq_data=seq_data,
  140. prompt_lens=prompt_lens,
  141. )
  142. return pooling_metadata