embedding_model_runner.py (6.7 KB)
  1. from typing import Dict, List, Optional, Set, Tuple
  2. import torch
  3. from aphrodite.attention import AttentionMetadata
  4. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  5. LoRAConfig, ModelConfig, ParallelConfig,
  6. SchedulerConfig, VisionLanguageConfig)
  7. from aphrodite.common.pooling_params import PoolingParams
  8. from aphrodite.common.sequence import (PoolerOutput, SequenceData,
  9. SequenceGroupMetadata)
  10. from aphrodite.distributed import broadcast_tensor_dict
  11. from aphrodite.lora.layers import LoRAMapping
  12. from aphrodite.lora.request import LoRARequest
  13. from aphrodite.modeling.pooling_metadata import PoolingMetadata
  14. from aphrodite.task_handler.model_runner import ModelRunner
  15. class EmbeddingModelRunner(ModelRunner):
  16. def __init__(
  17. self,
  18. model_config: ModelConfig,
  19. parallel_config: ParallelConfig,
  20. scheduler_config: SchedulerConfig,
  21. device_config: DeviceConfig,
  22. cache_config: CacheConfig,
  23. load_config: LoadConfig,
  24. lora_config: Optional[LoRAConfig],
  25. kv_cache_dtype: Optional[str] = "auto",
  26. is_driver_worker: bool = False,
  27. vision_language_config: Optional[VisionLanguageConfig] = None,
  28. ):
  29. super().__init__(model_config,
  30. parallel_config,
  31. scheduler_config,
  32. device_config,
  33. cache_config,
  34. load_config,
  35. lora_config=lora_config,
  36. kv_cache_dtype=kv_cache_dtype,
  37. is_driver_worker=is_driver_worker,
  38. vision_language_config=vision_language_config)
  39. @torch.inference_mode()
  40. def execute_model(
  41. self,
  42. seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
  43. kv_caches: List[torch.Tensor],
  44. ) -> Optional[PoolerOutput]:
  45. (input_tokens, input_positions, attn_metadata, pooling_metadata,
  46. lora_requests, lora_mapping, multi_modal_input
  47. ) = self.prepare_input_tensors(seq_group_metadata_list)
  48. if self.lora_config:
  49. self.set_active_loras(lora_requests, lora_mapping)
  50. # Currently cuda graph is only supported by the decode phase.
  51. prefill_meta = attn_metadata.prefill_metadata
  52. decode_meta = attn_metadata.decode_metadata
  53. if prefill_meta is None and decode_meta.use_cuda_graph:
  54. graph_batch_size = input_tokens.shape[0]
  55. model_executable = self.graph_runners[graph_batch_size]
  56. else:
  57. model_executable = self.model
  58. num_layers = self.model_config.get_num_layers(self.parallel_config)
  59. kv_caches = [None] * num_layers
  60. execute_model_kwargs = {
  61. "input_ids": input_tokens,
  62. "positions": input_positions,
  63. "kv_caches": kv_caches,
  64. "attn_metadata": attn_metadata,
  65. }
  66. if self.vision_language_config:
  67. execute_model_kwargs.update({"image_input": multi_modal_input})
  68. hidden_states = model_executable(**execute_model_kwargs)
  69. return self.model.pooler(hidden_states=hidden_states,
  70. pooling_metadata=pooling_metadata)
  71. def prepare_input_tensors(
  72. self,
  73. seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
  74. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
  75. Set[LoRARequest], LoRAMapping, torch.Tensor]:
  76. if self.is_driver_worker:
  77. assert seq_group_metadata_list is not None
  78. # Prepare input tensors.
  79. (
  80. input_tokens,
  81. input_positions,
  82. attn_metadata,
  83. seq_lens,
  84. _,
  85. lora_mapping,
  86. lora_requests,
  87. multi_modal_input,
  88. slot_mapping,
  89. num_prefill_tokens,
  90. num_decode_tokens,
  91. num_prefills,
  92. ) = self._prepare_model_input(seq_group_metadata_list)
  93. # Prepare PoolingMetadata
  94. pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
  95. seq_lens)
  96. metadata_dict = {
  97. "input_tokens": input_tokens,
  98. "input_positions": input_positions,
  99. "lora_requests": lora_requests,
  100. "lora_mapping": lora_mapping,
  101. "multi_modal_input": multi_modal_input,
  102. "num_prefill_tokens": num_prefill_tokens,
  103. "num_decode_tokens": num_decode_tokens,
  104. "slot_mapping": slot_mapping,
  105. "num_prefills": num_prefills,
  106. }
  107. if attn_metadata:
  108. metadata_dict.update(attn_metadata.asdict_zerocopy())
  109. broadcast_tensor_dict(metadata_dict, src=0)
  110. else:
  111. metadata_dict = broadcast_tensor_dict(src=0)
  112. input_tokens = metadata_dict.pop("input_tokens")
  113. input_positions = metadata_dict.pop("input_positions")
  114. lora_mapping = metadata_dict.pop("lora_mapping")
  115. lora_requests = metadata_dict.pop("lora_requests")
  116. multi_modal_input = metadata_dict.pop("multi_modal_input")
  117. if metadata_dict:
  118. attn_metadata = self.attn_backend.make_metadata(
  119. **metadata_dict)
  120. else:
  121. attn_metadata = None
  122. pooling_metadata = PoolingMetadata(seq_groups=None,
  123. seq_data=None,
  124. prompt_lens=None)
  125. return (input_tokens, input_positions, attn_metadata, pooling_metadata,
  126. lora_requests, lora_mapping, multi_modal_input)
  127. def _prepare_pooling(
  128. self,
  129. seq_group_metadata_list: List[SequenceGroupMetadata],
  130. prompt_lens: List[int],
  131. ) -> PoolingMetadata:
  132. """Prepare PoolingMetadata for the sequence group metadata list."""
  133. seq_groups: List[Tuple[List[int], PoolingParams]] = []
  134. for i, seq_group_metadata in enumerate(seq_group_metadata_list):
  135. seq_ids = list(seq_group_metadata.seq_data.keys())
  136. pooling_params = seq_group_metadata.pooling_params
  137. seq_groups.append((seq_ids, pooling_params))
  138. seq_data: Dict[int, SequenceData] = {}
  139. for seq_group_metadata in seq_group_metadata_list:
  140. seq_data.update(seq_group_metadata.seq_data)
  141. pooling_metadata = PoolingMetadata(
  142. seq_groups=seq_groups,
  143. seq_data=seq_data,
  144. prompt_lens=prompt_lens,
  145. )
  146. return pooling_metadata