# model_runner_base.py (6.3 KB)
  1. import dataclasses
  2. from abc import ABC, abstractmethod
  3. from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
  4. TypeVar)
  5. import torch
  6. from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
  7. SequenceGroupMetadata)
  8. from aphrodite.platforms import current_platform
  9. if TYPE_CHECKING:
  10. from aphrodite.attention import AttentionMetadata
  11. from aphrodite.attention.backends.abstract import AttentionBackend
  12. from aphrodite.modeling import SamplingMetadata
  13. T = TypeVar('T', bound="ModelRunnerInputBase")
  14. def _add_attn_metadata_broadcastable_dict(
  15. tensor_dict: Dict[str, Any],
  16. attn_metadata: Optional["AttentionMetadata"]) -> None:
  17. """
  18. Helper method to update tensor_dict with broadcastable
  19. AttentionMetadata fields.
  20. """
  21. if attn_metadata is not None:
  22. tensor_dict.update(attn_metadata.asdict_zerocopy())
  23. def _init_attn_metadata_from_tensor_dict(
  24. attn_backend: "AttentionBackend",
  25. tensor_dict: Dict[str, Any],
  26. ) -> Dict[str, Any]:
  27. """
  28. Helper method to initialize AttentionMetadata based on an
  29. AttentionBackend and broadcastable AttentionMetadata fields.
  30. """
  31. # Extract the fields used to create AttentionMetadata.
  32. valid_attn_kwargs = {}
  33. for field in dataclasses.fields(attn_backend.get_metadata_cls()):
  34. val = tensor_dict.pop(field.name, None)
  35. if val is not None:
  36. valid_attn_kwargs[field.name] = val
  37. attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
  38. tensor_dict["attn_metadata"] = attn_metadata
  39. return tensor_dict
  40. def _init_sampling_metadata_from_tensor_dict( # type: ignore
  41. tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
  42. """
  43. Helper method to initialize SamplingMetadata based on broadcastable
  44. SamplingMetadata fields.
  45. """
  46. from aphrodite.modeling import SamplingMetadata
  47. selected_token_indices = tensor_dict.pop("selected_token_indices", None)
  48. # An empty SamplingMetadata to signal that the worker should skip
  49. # sampling.
  50. if selected_token_indices is not None:
  51. tensor_dict["sampling_metadata"] = SamplingMetadata(
  52. seq_groups=None,
  53. selected_token_indices=selected_token_indices,
  54. categorized_sample_indices=None,
  55. num_prompts=0,
  56. )
  57. return tensor_dict
  58. def _add_sampling_metadata_broadcastable_dict(
  59. tensor_dict: Dict[str, Any],
  60. sampling_metadata: Optional["SamplingMetadata"]) -> None:
  61. """
  62. Helper method to update tensor_dict with broadcastable
  63. SamplingMetadata fields.
  64. """
  65. if sampling_metadata is not None:
  66. tensor_dict["selected_token_indices"] = (
  67. sampling_metadata.selected_token_indices)
@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(ABC):
    """Local inputs to each worker's model runner. May contain
    device-specific data. Different worker backends may have different methods
    of converting from the global ExecuteModelRequest produced by the LLM
    engine to the worker-local ModelRunnerInputBase objects.

    Model runners that support multi-GPU execution should define a
    ModelRunnerInputBase subclass, add their required fields, and specify how to
    serialize/deserialize a ModelInput for broadcast between workers.
    """

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """
        Extract broadcastable fields. Override for fields that require some
        custom deserialization.
        """
        # NOTE(review): intentionally not @abstractmethod, it seems — a
        # subclass that is never broadcast can skip overriding this and
        # still be instantiated; confirm against the worker backends.
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_broadcasted_tensor_dict(
        cls: Type[T],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> T:
        """
        Pop fields from the given tensor_dict and populate a new instance of
        ModelRunnerInputBase.
        """
        raise NotImplementedError
class ModelRunnerInputBuilderBase(ABC, Generic[T]):
    """A builder to create ModelRunnerInputBase objects.

    Subclasses accumulate per-sequence-group state via ``add_seq_group`` and
    materialize the final model input with ``build``.
    """

    @abstractmethod
    def add_seq_group(self, seq_group_metadata):
        """Accumulate one sequence group's metadata into the builder state."""
        raise NotImplementedError

    @abstractmethod
    def build(self, *args, **kwargs) -> T:
        """Build metadata with on-device tensors."""
        raise NotImplementedError
  107. class ModelRunnerBase(ABC, Generic[T]):
  108. """
  109. Model runner interface that abstracts a particular hardware and/or type of
  110. model. Model execution may communicate data with model runners in other
  111. processes, but it should not include control plane metadata communication.
  112. Each ModelRunnerBase subclass should define a corresponding
  113. ModelRunnerInputBase subclass.
  114. """
  115. # Map of request_id -> generator used for seeded random sampling
  116. generators: Dict[str, torch.Generator] = {}
  117. @abstractmethod
  118. def make_model_input_from_broadcasted_tensor_dict(
  119. self,
  120. tensor_dict: Dict[str, Any],
  121. ) -> T:
  122. """
  123. Make an instance of a ModelRunnerInputBase from the broadcasted tensor
  124. dict.
  125. """
  126. raise NotImplementedError
  127. @abstractmethod
  128. def prepare_model_input(
  129. self,
  130. seq_group_metadata_list: List[SequenceGroupMetadata],
  131. virtual_engine: int = 0,
  132. finished_requests_ids: Optional[List[str]] = None,
  133. ) -> T:
  134. """
  135. Prepare the inputs to ModelRunnerBase.execute_model from an execution
  136. request. This method may move data to the worker's local device. It is
  137. not allowed to communicate with other workers or devices.
  138. """
  139. raise NotImplementedError
  140. @current_platform.inference_mode()
  141. def execute_model(
  142. self,
  143. model_input: T,
  144. kv_caches: Optional[List[torch.Tensor]],
  145. intermediate_tensors: Optional[IntermediateTensors],
  146. num_steps: int = 1,
  147. ) -> Optional[List[SamplerOutput]]:
  148. """
  149. Execute the model on the given input.
  150. """
  151. raise NotImplementedError
  152. def get_generators(self, finished_request_ids: Optional[List[str]] = None):
  153. """
  154. Return dict of per-request generators used for random sampling.
  155. """
  156. # Clean up generators from completed requests
  157. if finished_request_ids:
  158. for request_id in finished_request_ids:
  159. self.generators.pop(request_id, None)
  160. return self.generators