from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type,
                    TypeVar)

from loguru import logger

from aphrodite.common.config import ModelConfig, VisionLanguageConfig

if TYPE_CHECKING:
    import torch
    from torch import nn


class MultiModalData:
    """
    Base class that contains multi-modal data.

    To add a new modality, add a new file under the ``multimodal`` directory.
    In this new file, subclass :class:`~MultiModalData` and
    :class:`~MultiModalPlugin`.

    Finally, register the new plugin with
    :const:`aphrodite.multimodal.MULTIMODAL_REGISTRY`.
    This enables models to call :meth:`MultiModalRegistry.register_input` for
    the new modality.
    """
    pass
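
# The sketch below shows what a concrete modality could look like; the
# ``ImageData`` name and its ``pixel_values`` field are hypothetical, not
# part of this module (real subclasses live in their own files under
# ``multimodal``):
#
#     class ImageData(MultiModalData):
#         """Raw pixel data for a single image."""
#
#         def __init__(self, pixel_values: "torch.Tensor") -> None:
#             self.pixel_values = pixel_values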


D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type["nn.Module"])

MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig],
                                    Dict[str, "torch.Tensor"]]
"""Return a dictionary to be passed as keyword arguments to
:meth:`torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers."""
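
# For illustration, a processor matching this signature could look like the
# sketch below; ``ImageData`` and the ``"pixel_values"`` key are assumptions,
# not names defined by this module:
#
#     def process_image(data: ImageData, model_config: ModelConfig,
#                       vlm_config: VisionLanguageConfig
#                       ) -> Dict[str, "torch.Tensor"]:
#         # Preprocess (e.g. resize/normalize) as the target model expects.
#         return {"pixel_values": data.pixel_values}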


class MultiModalPlugin(ABC, Generic[D]):
    """
    Base class that defines data processing logic for a specific modality.

    In particular, we adopt a registry pattern to dispatch data processing
    according to the model being used (considering that different models may
    process the same data differently). This registry is in turn used by
    :class:`~MultiModalRegistry`, which acts at a higher level
    (i.e., the modality of the data).
    """

    @classmethod
    def get_model_cls(cls, model_config: ModelConfig) -> Type["nn.Module"]:
        # Avoid circular import
        from aphrodite.modeling.model_loader import get_model_architecture

        return get_model_architecture(model_config)[0]

    def __init__(self) -> None:
        self._input_processors: Dict[Type["nn.Module"],
                                     MultiModalInputProcessor[D]] = {}

    @abstractmethod
    def get_data_type(self) -> Type[D]:
        """
        Get the modality (subclass of :class:`~MultiModalData`) served by
        this plugin.
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_processor(
            self, data: D, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
        """Return a dictionary to be passed as keyword arguments to
        :meth:`torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.
        """
        raise NotImplementedError

    def register_input_processor(self,
                                 processor: Optional[
                                     MultiModalInputProcessor[D]] = None):
        """
        Register an input processor for a model class.

        When the model receives input data that matches the modality served
        by this plugin (see :meth:`get_data_type`), the provided input
        processor is applied to preprocess the data. If ``None`` is provided,
        the default input processor is applied instead.
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._input_processors:
                logger.warning(
                    "Model class {} already has an input processor "
                    "registered to {}. It is overwritten by the new one.",
                    model_cls, self)

            self._input_processors[model_cls] = processor \
                or self._default_input_processor

            return model_cls

        return wrapper
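
    # A usage sketch: decorating a model class with a concrete plugin
    # instance. ``ImagePlugin`` and ``MyVisionModel`` are hypothetical names,
    # not defined by this module:
    #
    #     image_plugin = ImagePlugin()
    #
    #     @image_plugin.register_input_processor()  # None -> default
    #     class MyVisionModel(nn.Module):
    #         ...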

    def process_input(
            self, data: D, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
        """
        Apply an input processor to a :class:`~MultiModalData` instance passed
        to the model.

        The model is identified by ``model_config``. ``vlm_config`` is
        for compatibility purposes and may be merged into ``model_config``
        in the near future.
        """
        model_cls = self.get_model_cls(model_config)

        processor = self._input_processors.get(model_cls)
        if processor is None:
            raise KeyError(f"No input processor in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        return processor(data, model_config, vlm_config)
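

# An end-to-end sketch of dispatch (all names hypothetical): given a
# processor registered for the resolved model class, the returned dictionary
# is forwarded as keyword arguments to the model's forward():
#
#     mm_kwargs = image_plugin.process_input(
#         ImageData(pixel_values), model_config, vlm_config)
#     hidden_states = model(input_ids, positions, **mm_kwargs)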