123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260 |
- import dataclasses
- import pickle
- from abc import ABC, abstractmethod
- from datetime import datetime
- from functools import wraps
- from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
- Optional, Type, TypeVar)
- import torch
- from loguru import logger
- from torch import is_tensor
- from aphrodite.common.sequence import (IntermediateTensors,
- SequenceGroupMetadata)
- from aphrodite.modeling.layers.sampler import SamplerOutput
- from aphrodite.platforms import current_platform
- if TYPE_CHECKING:
- from aphrodite.attention import AttentionMetadata
- from aphrodite.attention.backends.abstract import AttentionBackend
- from aphrodite.modeling.sampling_metadata import SamplingMetadata
# Type variable for model-input classes. It is bound to
# BroadcastableModelInput and is used below both as the return type of
# BroadcastableModelInput.from_broadcasted_tensor_dict and as the Generic
# parameter of the builder / runner base classes.
T = TypeVar('T', bound="BroadcastableModelInput")
def _add_attn_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Any],
        attn_metadata: Optional["AttentionMetadata"]) -> None:
    """
    Merge the broadcastable AttentionMetadata fields into ``tensor_dict``.

    The fields are obtained from ``attn_metadata.asdict_zerocopy()`` and
    written into ``tensor_dict`` in place. When ``attn_metadata`` is None the
    dict is left untouched.
    """
    if attn_metadata is None:
        return
    broadcastable_fields = attn_metadata.asdict_zerocopy()
    tensor_dict.update(broadcastable_fields)
def _init_attn_metadata_from_tensor_dict(
    attn_backend: "AttentionBackend",
    tensor_dict: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Build an AttentionMetadata object out of the broadcasted fields.

    Every field declared on the backend's metadata dataclass is popped from
    ``tensor_dict``; the ones that are present and not None are forwarded to
    ``attn_backend.make_metadata``. The result is stored back under the
    "attn_metadata" key and the same (mutated) dict is returned.
    """
    metadata_cls = attn_backend.get_metadata_cls()

    # Remove every metadata field from the dict, whether or not it carries a
    # usable value.
    popped = {
        spec.name: tensor_dict.pop(spec.name, None)
        for spec in dataclasses.fields(metadata_cls)
    }

    # Only forward the fields that were actually broadcast.
    attn_kwargs = {
        name: value
        for name, value in popped.items() if value is not None
    }

    tensor_dict["attn_metadata"] = attn_backend.make_metadata(**attn_kwargs)
    return tensor_dict
def _init_sampling_metadata_from_tensor_dict(  # type: ignore
        tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Rebuild a SamplingMetadata object from its broadcasted fields.

    Pops "selected_token_indices" from ``tensor_dict``. When a value is
    present, a SamplingMetadata carrying only those indices is stored under
    "sampling_metadata". The same (mutated) dict is returned either way.
    """
    from aphrodite.modeling.sampling_metadata import SamplingMetadata

    token_indices = tensor_dict.pop("selected_token_indices", None)
    if token_indices is None:
        return tensor_dict

    # An otherwise-empty SamplingMetadata signals that the worker should skip
    # sampling.
    tensor_dict["sampling_metadata"] = SamplingMetadata(
        seq_groups=None,
        selected_token_indices=token_indices,
        categorized_sample_indices=None,
        num_prompts=0,
    )
    return tensor_dict
def _add_sampling_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Any],
        sampling_metadata: Optional["SamplingMetadata"]) -> None:
    """
    Copy the broadcastable SamplingMetadata fields into ``tensor_dict``.

    Only ``selected_token_indices`` is transferred; the dict is modified in
    place and is left untouched when ``sampling_metadata`` is None.
    """
    if sampling_metadata is None:
        return
    token_indices = sampling_metadata.selected_token_indices
    tensor_dict["selected_token_indices"] = token_indices
def _init_frozen_model_input_from_tensor_dict(
        frozen_model_input_cls: Type["ModelRunnerInputBase"],
        tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build a frozen model input from the broadcasted fields in ``tensor_dict``.

    Every dataclass field of ``frozen_model_input_cls`` is popped from
    ``tensor_dict``; fields that are present and not None become constructor
    keyword arguments. The new instance is stored under "frozen_model_input"
    and the same (mutated) dict is returned.
    """
    # The pop inside the condition removes each field from the dict even when
    # its value is None and therefore not forwarded.
    init_kwargs = {
        spec.name: value
        for spec in dataclasses.fields(frozen_model_input_cls)
        if (value := tensor_dict.pop(spec.name, None)) is not None
    }
    tensor_dict["frozen_model_input"] = frozen_model_input_cls(**init_kwargs)
    return tensor_dict
def dump_input_when_exception(exclude_args: Optional[List[int]] = None,
                              exclude_kwargs: Optional[List[str]] = None):
    """
    Decorator factory that pickles a call's inputs when the call raises.

    When the wrapped function raises an ``Exception``, its positional and
    keyword arguments are pickled to
    ``/tmp/err_<func name>_input_<timestamp>.pkl`` and an exception of the
    same type is re-raised (chained to the original) with a message that
    points at the dump file.

    Args:
        exclude_args: Positional indices that must not be dumped.
        exclude_kwargs: Keyword names that must not be dumped.

    Returns:
        A decorator; the decorated function returns whatever the wrapped
        function returns when no exception occurs.

    Raises:
        An exception of the same type as the one raised by the wrapped
        function. If writing the dump itself fails, that failure is logged
        and the re-raised message omits the dump file name, so the original
        error is never replaced by an I/O or pickling error.
    """
    # Membership is tested for every argument on every failure; build the
    # lookup sets once.
    skipped_positions = frozenset(exclude_args or ())
    skipped_keywords = frozenset(exclude_kwargs or ())

    def _inner(func):

        @wraps(func)
        def _wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as err:
                timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
                filename = f"/tmp/err_{func.__name__}_input_{timestamp}.pkl"
                # loguru formats messages with str.format-style braces;
                # "%s" placeholders would be emitted literally.
                logger.info("Writing input of failed execution to {}...",
                            filename)

                dumped_inputs = {
                    k: v
                    for k, v in kwargs.items() if k not in skipped_keywords
                }
                for i, arg in enumerate(args):
                    if i not in skipped_positions:
                        dumped_inputs[f"arg_{i}"] = arg

                # Only persist dtype and shape for kvcache tensors
                # (can be way too big otherwise)
                if (kv_caches := dumped_inputs.get("kv_caches")) \
                    and isinstance(kv_caches, Iterable):
                    dumped_inputs["kv_caches"] = [(t.dtype, t.shape)
                                                  for t in kv_caches
                                                  if is_tensor(t)]

                try:
                    with open(filename, "wb") as filep:
                        pickle.dump(dumped_inputs, filep)
                except Exception as dump_err:
                    # Dumping is best-effort: never let a failure here hide
                    # the original execution error.
                    logger.warning(
                        "Failed to write input of failed execution to {}: {}",
                        filename, str(dump_err))
                    summary = "Error in model execution"
                else:
                    logger.info(
                        "Completed writing input of failed execution to {}.",
                        filename)
                    summary = ("Error in model execution "
                               f"(input dumped to {filename})")

                raise type(err)(f"{summary}: {str(err)}") from err

        return _wrapper

    return _inner
class BroadcastableModelInput(ABC):
    """Interface for model inputs that can be turned into, and rebuilt from,
    a dict of broadcastable fields."""

    @abstractmethod
    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """
        Return the fields of this input that can be broadcast, as a dict.

        Implementations override this to deal with fields that need custom
        handling rather than a plain copy.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_broadcasted_tensor_dict(
        cls: Type[T],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> T:
        """
        Create a new instance of this class from a broadcasted dict.

        Implementations pop the fields they need out of ``tensor_dict``.
        ``attn_backend`` is optional and defaults to None.
        """
        raise NotImplementedError
@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(BroadcastableModelInput):
    """Worker-local inputs for a model runner; may hold device-specific data.

    Each worker backend may convert the global ExecuteModelRequest produced
    by the LLM engine into its own worker-local ModelRunnerInputBase object
    in a different way.

    A model runner that supports multi-GPU execution should declare a
    ModelRunnerInputBase subclass with the fields it needs, and define how a
    ModelInput is serialized/deserialized for broadcast between workers.
    """
class ModelRunnerInputBuilderBase(ABC, Generic[T]):
    """Abstract builder that produces ModelRunnerInputBase objects."""

    @abstractmethod
    def add_seq_group(self, seq_group_metadata):
        """Add one sequence group's metadata to the builder.

        NOTE(review): the original docstring was only "TBA"; the precise
        contract is left to the concrete builders.
        """
        raise NotImplementedError

    @abstractmethod
    def build(self, *args, **kwargs) -> T:
        """Build metadata with on-device tensors."""
        raise NotImplementedError
class ModelRunnerBase(ABC, Generic[T]):
    """
    Interface of a model runner, abstracting over a particular hardware
    and/or type of model.

    Model execution may exchange data with model runners in other processes,
    but it should not include control plane metadata communication.

    Every ModelRunnerBase subclass should define a matching
    ModelRunnerInputBase subclass.
    """

    # Map of request_id -> generator used for seeded random sampling.
    # Declared on the class, so the same dict object is seen by every
    # instance unless it is rebound.
    generators: Dict[str, torch.Generator] = {}

    @abstractmethod
    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> T:
        """
        Create a ModelRunnerInputBase instance from the broadcasted tensor
        dict.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> T:
        """
        Turn an execution request into the inputs expected by
        ModelRunnerBase.execute_model.

        Implementations may move data to the worker's local device, but they
        are not allowed to communicate with other workers or devices.
        """
        raise NotImplementedError

    @current_platform.inference_mode()
    def execute_model(
        self,
        model_input: T,
        kv_caches: Optional[List[torch.Tensor]],
        intermediate_tensors: Optional[IntermediateTensors],
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """
        Run the model on ``model_input``.

        The base implementation always raises NotImplementedError.
        """
        raise NotImplementedError

    def get_generators(self, finished_request_ids: Optional[List[str]] = None):
        """
        Return the request_id -> generator dict used for random sampling.

        Entries for the given finished request ids are removed first; ids
        that are not in the dict are ignored.
        """
        # A missing or empty id list means there is nothing to clean up.
        for request_id in finished_request_ids or ():
            self.generators.pop(request_id, None)
        return self.generators
|