model_runner_base.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. import dataclasses
  2. from abc import ABC, abstractmethod
  3. from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
  4. TypeVar)
  5. import torch
  6. from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
  7. SequenceGroupMetadata)
  8. if TYPE_CHECKING:
  9. from aphrodite.attention import AttentionMetadata
  10. from aphrodite.attention.backends.abstract import AttentionBackend
  11. from aphrodite.modeling import SamplingMetadata
  12. T = TypeVar('T', bound="ModelRunnerInputBase")
  13. def _add_attn_metadata_broadcastable_dict(
  14. tensor_dict: Dict[str, Any],
  15. attn_metadata: Optional["AttentionMetadata"]) -> None:
  16. """
  17. Helper method to update tensor_dict with broadcastable
  18. AttentionMetadata fields.
  19. """
  20. if attn_metadata is not None:
  21. tensor_dict.update(attn_metadata.asdict_zerocopy())
  22. def _init_attn_metadata_from_tensor_dict(
  23. attn_backend: "AttentionBackend",
  24. tensor_dict: Dict[str, Any],
  25. ) -> Dict[str, Any]:
  26. """
  27. Helper method to initialize AttentionMetadata based on an
  28. AttentionBackend and broadcastable AttentionMetadata fields.
  29. """
  30. # Extract the fields used to create AttentionMetadata.
  31. valid_attn_kwargs = {}
  32. for field in dataclasses.fields(attn_backend.get_metadata_cls()):
  33. val = tensor_dict.pop(field.name, None)
  34. if val is not None:
  35. valid_attn_kwargs[field.name] = val
  36. attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
  37. tensor_dict["attn_metadata"] = attn_metadata
  38. return tensor_dict
  39. def _init_sampling_metadata_from_tensor_dict( # type: ignore
  40. tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
  41. """
  42. Helper method to initialize SamplingMetadata based on broadcastable
  43. SamplingMetadata fields.
  44. """
  45. from aphrodite.modeling import SamplingMetadata
  46. selected_token_indices = tensor_dict.pop("selected_token_indices", None)
  47. # An empty SamplingMetadata to signal that the worker should skip
  48. # sampling.
  49. if selected_token_indices is not None:
  50. tensor_dict["sampling_metadata"] = SamplingMetadata(
  51. seq_groups=None,
  52. selected_token_indices=selected_token_indices,
  53. categorized_sample_indices=None,
  54. num_prompts=0,
  55. )
  56. return tensor_dict
  57. def _add_sampling_metadata_broadcastable_dict(
  58. tensor_dict: Dict[str, Any],
  59. sampling_metadata: Optional["SamplingMetadata"]) -> None:
  60. """
  61. Helper method to update tensor_dict with broadcastable
  62. SamplingMetadata fields.
  63. """
  64. if sampling_metadata is not None:
  65. tensor_dict["selected_token_indices"] = (
  66. sampling_metadata.selected_token_indices)
@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(ABC):
    """Local inputs to each worker's model runner. May contain
    device-specific data. Different worker backends may have different methods
    of converting from the global ExecuteModelRequest produced by the LLM
    engine to the worker-local ModelRunnerInputBase objects.

    Model runners that support multi-GPU execution should define a
    ModelRunnerInputBase subclass, add their required fields, and specify how to
    serialize/deserialize a ModelInput for broadcast between workers.
    """

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """
        Extract broadcastable fields. Override for fields that require some
        custom deserialization.
        """
        # NOTE(review): unlike from_broadcasted_tensor_dict this is not
        # @abstractmethod -- presumably optional for inputs that are never
        # broadcast between workers; confirm against subclasses.
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_broadcasted_tensor_dict(
        cls: Type[T],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> T:
        """
        Pop fields from the given tensor_dict and populate a new instance of
        ModelRunnerInputBase.
        """
        raise NotImplementedError
class ModelRunnerBase(ABC, Generic[T]):
    """
    Model runner interface that abstracts a particular hardware and/or type of
    model. Model execution may communicate data with model runners in other
    processes, but it should not include control plane metadata communication.

    Each ModelRunnerBase subclass should define a corresponding
    ModelRunnerInputBase subclass.
    """

    @abstractmethod
    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> T:
        """
        Make an instance of a ModelRunnerInputBase from the broadcasted tensor
        dict.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> T:
        """
        Prepare the inputs to ModelRunnerBase.execute_model from an execution
        request. This method may move data to the worker's local device. It is
        not allowed to communicate with other workers or devices.
        """
        raise NotImplementedError

    # NOTE: the decorator applies to this base implementation only; a subclass
    # override replaces the decorated function entirely and must re-apply
    # @torch.inference_mode() itself if it wants autograd disabled.
    @torch.inference_mode()
    def execute_model(
        self,
        model_input: T,
        kv_caches: Optional[List[torch.Tensor]],
        intermediate_tensors: Optional[IntermediateTensors],
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """
        Execute the model on the given input.
        """
        # Not @abstractmethod, but the base body always raises; every concrete
        # runner is expected to override it.
        raise NotImplementedError