model_runner_base.py

import dataclasses
import pickle
from abc import ABC, abstractmethod
from datetime import datetime
from functools import wraps
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
                    Optional, Type, TypeVar)

import torch
from loguru import logger
from torch import is_tensor

from aphrodite.common.sequence import (IntermediateTensors,
                                       SequenceGroupMetadata)
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.platforms import current_platform

if TYPE_CHECKING:
    from aphrodite.attention import AttentionMetadata
    from aphrodite.attention.backends.abstract import AttentionBackend
    from aphrodite.modeling.sampling_metadata import SamplingMetadata

T = TypeVar('T', bound="BroadcastableModelInput")


def _add_attn_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Any],
        attn_metadata: Optional["AttentionMetadata"]) -> None:
    """
    Helper method to update tensor_dict with broadcastable
    AttentionMetadata fields.
    """
    if attn_metadata is not None:
        tensor_dict.update(attn_metadata.asdict_zerocopy())


def _init_attn_metadata_from_tensor_dict(
    attn_backend: "AttentionBackend",
    tensor_dict: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Helper method to initialize AttentionMetadata based on an
    AttentionBackend and broadcastable AttentionMetadata fields.
    """
    # Extract the fields used to create AttentionMetadata.
    valid_attn_kwargs = {}
    for field in dataclasses.fields(attn_backend.get_metadata_cls()):
        val = tensor_dict.pop(field.name, None)
        if val is not None:
            valid_attn_kwargs[field.name] = val

    attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
    tensor_dict["attn_metadata"] = attn_metadata
    return tensor_dict
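

# A minimal sketch of how the two helpers above compose, assuming a concrete
# `attn_backend` and the driver's `attn_metadata` (both stand-ins here, not
# defined in this module). The driver flattens the metadata into a
# broadcastable dict; after the broadcast, each worker rebuilds an
# AttentionMetadata instance from that dict.
def _example_attn_metadata_round_trip(
        attn_backend: "AttentionBackend",
        attn_metadata: "AttentionMetadata") -> Dict[str, Any]:
    tensor_dict: Dict[str, Any] = {}
    _add_attn_metadata_broadcastable_dict(tensor_dict, attn_metadata)
    # ... tensor_dict would be broadcast to the other workers here ...
    return _init_attn_metadata_from_tensor_dict(attn_backend, tensor_dict)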


def _init_sampling_metadata_from_tensor_dict(  # type: ignore
        tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Helper method to initialize SamplingMetadata based on broadcastable
    SamplingMetadata fields.
    """
    from aphrodite.modeling.sampling_metadata import SamplingMetadata

    selected_token_indices = tensor_dict.pop("selected_token_indices", None)
    # An empty SamplingMetadata to signal that the worker should skip
    # sampling.
    if selected_token_indices is not None:
        tensor_dict["sampling_metadata"] = SamplingMetadata(
            seq_groups=None,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=None,
            num_prompts=0,
        )
    return tensor_dict


def _add_sampling_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Any],
        sampling_metadata: Optional["SamplingMetadata"]) -> None:
    """
    Helper method to update tensor_dict with broadcastable
    SamplingMetadata fields.
    """
    if sampling_metadata is not None:
        tensor_dict["selected_token_indices"] = (
            sampling_metadata.selected_token_indices)
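

# A minimal sketch of the sampling-metadata round trip, assuming a
# `sampling_metadata` object built elsewhere by the driver worker. Only
# `selected_token_indices` survives the broadcast; the rebuilt
# SamplingMetadata is otherwise empty, which signals non-driver workers to
# skip sampling.
def _example_sampling_metadata_round_trip(
        sampling_metadata: "SamplingMetadata") -> Dict[str, Any]:
    tensor_dict: Dict[str, Any] = {}
    _add_sampling_metadata_broadcastable_dict(tensor_dict, sampling_metadata)
    # ... tensor_dict would be broadcast to the other workers here ...
    return _init_sampling_metadata_from_tensor_dict(tensor_dict)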


def _init_frozen_model_input_from_tensor_dict(
        frozen_model_input_cls: Type["ModelRunnerInputBase"],
        tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Helper method to initialize a frozen ModelInput based on broadcastable
    ModelInput fields.
    """
    valid_tensor_kwargs = {}
    for field in dataclasses.fields(frozen_model_input_cls):
        val = tensor_dict.pop(field.name, None)
        if val is not None:
            valid_tensor_kwargs[field.name] = val

    frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs)
    tensor_dict["frozen_model_input"] = frozen_model_input
    return tensor_dict


def dump_input_when_exception(exclude_args: Optional[List[int]] = None,
                              exclude_kwargs: Optional[List[str]] = None):

    def _inner(func):

        @wraps(func)
        def _wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as err:
                timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
                filename = f"/tmp/err_{func.__name__}_input_{timestamp}.pkl"
                logger.info("Writing input of failed execution to "
                            f"{filename}...")
                with open(filename, "wb") as filep:
                    dumped_inputs = {
                        k: v
                        for k, v in kwargs.items()
                        if k not in (exclude_kwargs or [])
                    }
                    for i, arg in enumerate(args):
                        if i not in (exclude_args or []):
                            dumped_inputs[f"arg_{i}"] = arg

                    # Only persist dtype and shape for kv_cache tensors
                    # (can be way too big otherwise)
                    if (kv_caches := dumped_inputs.get("kv_caches")) \
                            and isinstance(kv_caches, Iterable):
                        dumped_inputs["kv_caches"] = [(t.dtype, t.shape)
                                                      for t in kv_caches
                                                      if is_tensor(t)]
                    try:
                        pickle.dump(dumped_inputs, filep)
                    except Exception as pickle_err:
                        logger.warning(
                            "Failed to pickle inputs of failed execution: "
                            f"{str(pickle_err)}")
                        raise type(err)(f"Error in model execution: "
                                        f"{str(err)}") from err

                logger.info("Completed writing input of failed execution to "
                            f"{filename}.")
                raise type(err)(
                    f"Error in model execution (input dumped to {filename}): "
                    f"{str(err)}") from err

        return _wrapper

    return _inner
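

# A minimal usage sketch for the decorator above; `_ExampleRunner` and its
# method are hypothetical. Excluding positional arg 0 keeps `self` out of
# the pickle; on any exception, the remaining inputs are dumped to /tmp
# before the error is re-raised with the dump path attached.
class _ExampleRunner:

    @dump_input_when_exception(exclude_args=[0])
    def execute_model(self, model_input: Any,
                      kv_caches: Optional[List[torch.Tensor]] = None):
        # Any exception raised here triggers the dump; kv_caches passed by
        # keyword are reduced to (dtype, shape) pairs before pickling.
        raise NotImplementedError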


class BroadcastableModelInput(ABC):

    @abstractmethod
    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """
        Extract broadcastable fields. Override for fields that require some
        custom deserialization.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_broadcasted_tensor_dict(
        cls: Type[T],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> T:
        """
        Pop fields from the given tensor_dict and populate a new instance of
        BroadcastableModelInput.
        """
        raise NotImplementedError


@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(BroadcastableModelInput):
    """Local inputs to each worker's model runner. May contain
    device-specific data. Different worker backends may have different
    methods of converting from the global ExecuteModelRequest produced by
    the LLM engine to the worker-local ModelRunnerInputBase objects.

    Model runners that support multi-GPU execution should define a
    ModelRunnerInputBase subclass, add their required fields, and specify
    how to serialize/deserialize a ModelInput for broadcast between workers.
    """
    pass
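

# A minimal sketch of a concrete ModelRunnerInputBase subclass, assuming a
# single hypothetical `input_tokens` tensor field. It shows the intended
# pattern: as_broadcastable_tensor_dict() flattens the fields plus the
# attention metadata, and from_broadcasted_tensor_dict() rebuilds the
# instance using the helpers defined above.
@dataclasses.dataclass(frozen=True)
class _ExampleModelInput(ModelRunnerInputBase):
    input_tokens: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        tensor_dict: Dict[str, Any] = {"input_tokens": self.input_tokens}
        _add_attn_metadata_broadcastable_dict(tensor_dict,
                                              self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "_ExampleModelInput":
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)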


class ModelRunnerInputBuilderBase(ABC, Generic[T]):
    """A builder to create ModelRunnerInputBase objects.
    """

    @abstractmethod
    def add_seq_group(self, seq_group_metadata):
        """Add a sequence group's metadata to the in-progress batch."""
        raise NotImplementedError

    @abstractmethod
    def build(self, *args, **kwargs) -> T:
        """Build metadata with on-device tensors."""
        raise NotImplementedError


class ModelRunnerBase(ABC, Generic[T]):
    """
    Model runner interface that abstracts a particular hardware and/or type
    of model. Model execution may communicate data with model runners in
    other processes, but it should not include control plane metadata
    communication.

    Each ModelRunnerBase subclass should define a corresponding
    ModelRunnerInputBase subclass.
    """

    # Map of request_id -> generator used for seeded random sampling
    generators: Dict[str, torch.Generator] = {}

    @abstractmethod
    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> T:
        """
        Make an instance of a ModelRunnerInputBase from the broadcasted
        tensor dict.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> T:
        """
        Prepare the inputs to ModelRunnerBase.execute_model from an
        execution request. This method may move data to the worker's local
        device. It is not allowed to communicate with other workers or
        devices.
        """
        raise NotImplementedError

    @current_platform.inference_mode()
    def execute_model(
        self,
        model_input: T,
        kv_caches: Optional[List[torch.Tensor]],
        intermediate_tensors: Optional[IntermediateTensors],
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """
        Execute the model on the given input.
        """
        raise NotImplementedError

    def get_generators(self,
                       finished_request_ids: Optional[List[str]] = None):
        """
        Return dict of per-request generators used for random sampling.
        """
        # Clean up generators from completed requests
        if finished_request_ids:
            for request_id in finished_request_ids:
                self.generators.pop(request_id, None)
        return self.generators
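

# A minimal sketch of how a worker might drive a ModelRunnerBase subclass,
# assuming hypothetical `runner`, `seq_group_metadata_list`, and `kv_caches`
# objects supplied by the caller. The driver prepares the input locally; in
# a multi-worker setup the broadcast step (elided below) would run between
# preparation and execution.
def _example_worker_step(
    runner: ModelRunnerBase,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    kv_caches: Optional[List[torch.Tensor]],
) -> Optional[List[SamplerOutput]]:
    model_input = runner.prepare_model_input(seq_group_metadata_list)
    # ... as_broadcastable_tensor_dict() / broadcast /
    # make_model_input_from_broadcasted_tensor_dict() would run here ...
    return runner.execute_model(model_input, kv_caches,
                                intermediate_tensors=None)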