# model_runner_base.py
  1. import dataclasses
  2. from abc import ABC, abstractmethod
  3. from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
  4. TypeVar)
  5. import torch
  6. from aphrodite.common.sequence import (IntermediateTensors,
  7. SequenceGroupMetadata)
  8. from aphrodite.modeling.layers.sampler import SamplerOutput
  9. from aphrodite.platforms import current_platform
  10. if TYPE_CHECKING:
  11. from aphrodite.attention import AttentionMetadata
  12. from aphrodite.attention.backends.abstract import AttentionBackend
  13. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  14. T = TypeVar('T', bound="BroadcastableModelInput")
  15. def _add_attn_metadata_broadcastable_dict(
  16. tensor_dict: Dict[str, Any],
  17. attn_metadata: Optional["AttentionMetadata"]) -> None:
  18. """
  19. Helper method to update tensor_dict with broadcastable
  20. AttentionMetadata fields.
  21. """
  22. if attn_metadata is not None:
  23. tensor_dict.update(attn_metadata.asdict_zerocopy())
  24. def _init_attn_metadata_from_tensor_dict(
  25. attn_backend: "AttentionBackend",
  26. tensor_dict: Dict[str, Any],
  27. ) -> Dict[str, Any]:
  28. """
  29. Helper method to initialize AttentionMetadata based on an
  30. AttentionBackend and broadcastable AttentionMetadata fields.
  31. """
  32. # Extract the fields used to create AttentionMetadata.
  33. valid_attn_kwargs = {}
  34. for field in dataclasses.fields(attn_backend.get_metadata_cls()):
  35. val = tensor_dict.pop(field.name, None)
  36. if val is not None:
  37. valid_attn_kwargs[field.name] = val
  38. attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
  39. tensor_dict["attn_metadata"] = attn_metadata
  40. return tensor_dict
  41. def _init_sampling_metadata_from_tensor_dict( # type: ignore
  42. tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
  43. """
  44. Helper method to initialize SamplingMetadata based on broadcastable
  45. SamplingMetadata fields.
  46. """
  47. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  48. selected_token_indices = tensor_dict.pop("selected_token_indices", None)
  49. # An empty SamplingMetadata to signal that the worker should skip
  50. # sampling.
  51. if selected_token_indices is not None:
  52. tensor_dict["sampling_metadata"] = SamplingMetadata(
  53. seq_groups=None,
  54. selected_token_indices=selected_token_indices,
  55. categorized_sample_indices=None,
  56. num_prompts=0,
  57. )
  58. return tensor_dict
  59. def _add_sampling_metadata_broadcastable_dict(
  60. tensor_dict: Dict[str, Any],
  61. sampling_metadata: Optional["SamplingMetadata"]) -> None:
  62. """
  63. Helper method to update tensor_dict with broadcastable
  64. SamplingMetadata fields.
  65. """
  66. if sampling_metadata is not None:
  67. tensor_dict["selected_token_indices"] = (
  68. sampling_metadata.selected_token_indices)
  69. def _init_frozen_model_input_from_tensor_dict(
  70. frozen_model_input_cls: Type["ModelRunnerInputBase"],
  71. tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
  72. """
  73. Helper method to initialize a frozen ModelInput based on broadcastable
  74. """
  75. valid_tensor_kwargs = {}
  76. for field in dataclasses.fields(frozen_model_input_cls):
  77. val = tensor_dict.pop(field.name, None)
  78. if val is not None:
  79. valid_tensor_kwargs[field.name] = val
  80. frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs)
  81. tensor_dict["frozen_model_input"] = frozen_model_input
  82. return tensor_dict
  83. class BroadcastableModelInput(ABC):
  84. @abstractmethod
  85. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  86. """
  87. Extract broadcastable fields. Override for fields that require some
  88. custom deserialization.
  89. """
  90. raise NotImplementedError
  91. @classmethod
  92. @abstractmethod
  93. def from_broadcasted_tensor_dict(
  94. cls: Type[T],
  95. tensor_dict: Dict[str, Any],
  96. attn_backend: Optional["AttentionBackend"] = None,
  97. ) -> T:
  98. """
  99. Pop fields from the given tensor_dict and populate a new instance of
  100. BroadcastableModelInput.
  101. """
  102. raise NotImplementedError
  103. @dataclasses.dataclass(frozen=True)
  104. class ModelRunnerInputBase(BroadcastableModelInput):
  105. """Local inputs to each worker's model runner. May contain
  106. device-specific data. Different worker backends may have different methods
  107. of converting from the global ExecuteModelRequest produced by the LLM
  108. engine to the worker-local ModelRunnerInputBase objects.
  109. Model runners that support multi-GPU execution should define a
  110. ModelRunnerInputBase subclass, add their required fields, and specify how to
  111. serialize/deserialize a ModelInput for broadcast between workers.
  112. """
  113. pass
  114. class ModelRunnerInputBuilderBase(ABC, Generic[T]):
  115. """A builder to create ModelRunnerInputBase objects.
  116. """
  117. @abstractmethod
  118. def add_seq_group(self, seq_group_metadata):
  119. """TBA"""
  120. raise NotImplementedError
  121. @abstractmethod
  122. def build(self, *args, **kwargs) -> T:
  123. """Build metadata with on-device tensors."""
  124. raise NotImplementedError
  125. class ModelRunnerBase(ABC, Generic[T]):
  126. """
  127. Model runner interface that abstracts a particular hardware and/or type of
  128. model. Model execution may communicate data with model runners in other
  129. processes, but it should not include control plane metadata communication.
  130. Each ModelRunnerBase subclass should define a corresponding
  131. ModelRunnerInputBase subclass.
  132. """
  133. # Map of request_id -> generator used for seeded random sampling
  134. generators: Dict[str, torch.Generator] = {}
  135. @abstractmethod
  136. def make_model_input_from_broadcasted_tensor_dict(
  137. self,
  138. tensor_dict: Dict[str, Any],
  139. ) -> T:
  140. """
  141. Make an instance of a ModelRunnerInputBase from the broadcasted tensor
  142. dict.
  143. """
  144. raise NotImplementedError
  145. @abstractmethod
  146. def prepare_model_input(
  147. self,
  148. seq_group_metadata_list: List[SequenceGroupMetadata],
  149. virtual_engine: int = 0,
  150. finished_requests_ids: Optional[List[str]] = None,
  151. ) -> T:
  152. """
  153. Prepare the inputs to ModelRunnerBase.execute_model from an execution
  154. request. This method may move data to the worker's local device. It is
  155. not allowed to communicate with other workers or devices.
  156. """
  157. raise NotImplementedError
  158. @current_platform.inference_mode()
  159. def execute_model(
  160. self,
  161. model_input: T,
  162. kv_caches: Optional[List[torch.Tensor]],
  163. intermediate_tensors: Optional[IntermediateTensors],
  164. num_steps: int = 1,
  165. ) -> Optional[List[SamplerOutput]]:
  166. """
  167. Execute the model on the given input.
  168. """
  169. raise NotImplementedError
  170. def get_generators(self, finished_request_ids: Optional[List[str]] = None):
  171. """
  172. Return dict of per-request generators used for random sampling.
  173. """
  174. # Clean up generators from completed requests
  175. if finished_request_ids:
  176. for request_id in finished_request_ids:
  177. self.generators.pop(request_id, None)
  178. return self.generators