1
0

model_runner_base.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. import dataclasses
  2. from abc import ABC, abstractmethod
  3. from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
  4. TypeVar)
  5. import torch
  6. from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
  7. SequenceGroupMetadata)
  8. from aphrodite.platforms import current_platform
  9. if TYPE_CHECKING:
  10. from aphrodite.attention import AttentionMetadata
  11. from aphrodite.attention.backends.abstract import AttentionBackend
  12. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  13. T = TypeVar('T', bound="BroadcastableModelInput")
  14. def _add_attn_metadata_broadcastable_dict(
  15. tensor_dict: Dict[str, Any],
  16. attn_metadata: Optional["AttentionMetadata"]) -> None:
  17. """
  18. Helper method to update tensor_dict with broadcastable
  19. AttentionMetadata fields.
  20. """
  21. if attn_metadata is not None:
  22. tensor_dict.update(attn_metadata.asdict_zerocopy())
  23. def _init_attn_metadata_from_tensor_dict(
  24. attn_backend: "AttentionBackend",
  25. tensor_dict: Dict[str, Any],
  26. ) -> Dict[str, Any]:
  27. """
  28. Helper method to initialize AttentionMetadata based on an
  29. AttentionBackend and broadcastable AttentionMetadata fields.
  30. """
  31. # Extract the fields used to create AttentionMetadata.
  32. valid_attn_kwargs = {}
  33. for field in dataclasses.fields(attn_backend.get_metadata_cls()):
  34. val = tensor_dict.pop(field.name, None)
  35. if val is not None:
  36. valid_attn_kwargs[field.name] = val
  37. attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
  38. tensor_dict["attn_metadata"] = attn_metadata
  39. return tensor_dict
  40. def _init_sampling_metadata_from_tensor_dict( # type: ignore
  41. tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
  42. """
  43. Helper method to initialize SamplingMetadata based on broadcastable
  44. SamplingMetadata fields.
  45. """
  46. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  47. selected_token_indices = tensor_dict.pop("selected_token_indices", None)
  48. # An empty SamplingMetadata to signal that the worker should skip
  49. # sampling.
  50. if selected_token_indices is not None:
  51. tensor_dict["sampling_metadata"] = SamplingMetadata(
  52. seq_groups=None,
  53. selected_token_indices=selected_token_indices,
  54. categorized_sample_indices=None,
  55. num_prompts=0,
  56. )
  57. return tensor_dict
  58. def _add_sampling_metadata_broadcastable_dict(
  59. tensor_dict: Dict[str, Any],
  60. sampling_metadata: Optional["SamplingMetadata"]) -> None:
  61. """
  62. Helper method to update tensor_dict with broadcastable
  63. SamplingMetadata fields.
  64. """
  65. if sampling_metadata is not None:
  66. tensor_dict["selected_token_indices"] = (
  67. sampling_metadata.selected_token_indices)
  68. def _init_frozen_model_input_from_tensor_dict(
  69. frozen_model_input_cls: Type["ModelRunnerInputBase"],
  70. tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
  71. """
  72. Helper method to initialize a frozen ModelInput based on broadcastable
  73. """
  74. valid_tensor_kwargs = {}
  75. for field in dataclasses.fields(frozen_model_input_cls):
  76. val = tensor_dict.pop(field.name, None)
  77. if val is not None:
  78. valid_tensor_kwargs[field.name] = val
  79. frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs)
  80. tensor_dict["frozen_model_input"] = frozen_model_input
  81. return tensor_dict
  82. class BroadcastableModelInput(ABC):
  83. @abstractmethod
  84. def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
  85. """
  86. Extract broadcastable fields. Override for fields that require some
  87. custom deserialization.
  88. """
  89. raise NotImplementedError
  90. @classmethod
  91. @abstractmethod
  92. def from_broadcasted_tensor_dict(
  93. cls: Type[T],
  94. tensor_dict: Dict[str, Any],
  95. attn_backend: Optional["AttentionBackend"] = None,
  96. ) -> T:
  97. """
  98. Pop fields from the given tensor_dict and populate a new instance of
  99. BroadcastableModelInput.
  100. """
  101. raise NotImplementedError
  102. @dataclasses.dataclass(frozen=True)
  103. class ModelRunnerInputBase(BroadcastableModelInput):
  104. """Local inputs to each worker's model runner. May contain
  105. device-specific data. Different worker backends may have different methods
  106. of converting from the global ExecuteModelRequest produced by the LLM
  107. engine to the worker-local ModelRunnerInputBase objects.
  108. Model runners that support multi-GPU execution should define a
  109. ModelRunnerInputBase subclass, add their required fields, and specify how to
  110. serialize/deserialize a ModelInput for broadcast between workers.
  111. """
  112. pass
  113. class ModelRunnerInputBuilderBase(ABC, Generic[T]):
  114. """A builder to create ModelRunnerInputBase objects.
  115. """
  116. @abstractmethod
  117. def add_seq_group(self, seq_group_metadata):
  118. """TBA"""
  119. raise NotImplementedError
  120. @abstractmethod
  121. def build(self, *args, **kwargs) -> T:
  122. """Build metadata with on-device tensors."""
  123. raise NotImplementedError
  124. class ModelRunnerBase(ABC, Generic[T]):
  125. """
  126. Model runner interface that abstracts a particular hardware and/or type of
  127. model. Model execution may communicate data with model runners in other
  128. processes, but it should not include control plane metadata communication.
  129. Each ModelRunnerBase subclass should define a corresponding
  130. ModelRunnerInputBase subclass.
  131. """
  132. # Map of request_id -> generator used for seeded random sampling
  133. generators: Dict[str, torch.Generator] = {}
  134. @abstractmethod
  135. def make_model_input_from_broadcasted_tensor_dict(
  136. self,
  137. tensor_dict: Dict[str, Any],
  138. ) -> T:
  139. """
  140. Make an instance of a ModelRunnerInputBase from the broadcasted tensor
  141. dict.
  142. """
  143. raise NotImplementedError
  144. @abstractmethod
  145. def prepare_model_input(
  146. self,
  147. seq_group_metadata_list: List[SequenceGroupMetadata],
  148. virtual_engine: int = 0,
  149. finished_requests_ids: Optional[List[str]] = None,
  150. ) -> T:
  151. """
  152. Prepare the inputs to ModelRunnerBase.execute_model from an execution
  153. request. This method may move data to the worker's local device. It is
  154. not allowed to communicate with other workers or devices.
  155. """
  156. raise NotImplementedError
  157. @current_platform.inference_mode()
  158. def execute_model(
  159. self,
  160. model_input: T,
  161. kv_caches: Optional[List[torch.Tensor]],
  162. intermediate_tensors: Optional[IntermediateTensors],
  163. num_steps: int = 1,
  164. ) -> Optional[List[SamplerOutput]]:
  165. """
  166. Execute the model on the given input.
  167. """
  168. raise NotImplementedError
  169. def get_generators(self, finished_request_ids: Optional[List[str]] = None):
  170. """
  171. Return dict of per-request generators used for random sampling.
  172. """
  173. # Clean up generators from completed requests
  174. if finished_request_ids:
  175. for request_id in finished_request_ids:
  176. self.generators.pop(request_id, None)
  177. return self.generators