# base.py
  1. from abc import ABC, abstractmethod
  2. from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type,
  3. TypeVar)
  4. from loguru import logger
  5. from aphrodite.common.config import ModelConfig, VisionLanguageConfig
  6. if TYPE_CHECKING:
  7. import torch
  8. from torch import nn
  9. class MultiModalData:
  10. """
  11. Base class that contains multi-modal data.
  12. To add a new modality, add a new file under ``multimodal`` directory.
  13. In this new file, subclass :class:`~MultiModalData` and
  14. :class:`~MultiModalPlugin`.
  15. Finally, register the new plugin to
  16. :const:`aphrodite.multimodal.MULTIMODAL_REGISTRY`.
  17. This enables models to call :meth:`MultiModalRegistry.register_input` for
  18. the new modality.
  19. """
  20. pass
  21. D = TypeVar("D", bound=MultiModalData)
  22. N = TypeVar("N", bound=Type["nn.Module"])
  23. MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig],
  24. Dict[str, "torch.Tensor"]]
  25. """Return a dictionary to be passed as keyword arguments to
  26. :meth:`torch.nn.Module.forward`. This is similar in concept to tokenizers
  27. and processors in HuggingFace Transformers."""
  28. class MultiModalPlugin(ABC, Generic[D]):
  29. """
  30. Base class that defines data processing logic for a specific modality.
  31. In particular, we adopt a registry pattern to dispatch data processing
  32. according to the model being used (considering that different models may
  33. process the same data differently). This registry is in turn used by
  34. :class:`~MultiModalRegistry` which acts at a higher level
  35. (i.e., the modality of the data).
  36. """
  37. @classmethod
  38. def get_model_cls(cls, model_config: ModelConfig) -> Type["nn.Module"]:
  39. # Avoid circular import
  40. from aphrodite.modeling.model_loader import get_model_architecture
  41. return get_model_architecture(model_config)[0]
  42. def __init__(self) -> None:
  43. self._input_processors: Dict[Type["nn.Module"],
  44. MultiModalInputProcessor[D]] = {}
  45. @abstractmethod
  46. def get_data_type(self) -> Type[D]:
  47. """
  48. Get the modality (subclass of :class:`~MultiModalData`) served by
  49. this plugin.
  50. """
  51. raise NotImplementedError
  52. @abstractmethod
  53. def _default_input_processor(
  54. self, data: D, model_config: ModelConfig,
  55. vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
  56. """Return a dictionary to be passed as keyword arguments to
  57. :meth:`torch.nn.Module.forward`. This is similar in concept to
  58. tokenizers and processors in HuggingFace Transformers.
  59. """
  60. raise NotImplementedError
  61. def register_input_processor(self,
  62. processor: Optional[
  63. MultiModalInputProcessor[D]] = None):
  64. """
  65. Register an input processor to a model class.
  66. When the model receives input data that matches the modality served by
  67. this plugin (see :meth:`get_data_type`), the provided input processor is
  68. applied to preprocess the data. If `None` is provided, then the default
  69. input processor is applied instead.
  70. """
  71. def wrapper(model_cls: N) -> N:
  72. if model_cls in self._input_processors:
  73. logger.warning(
  74. "Model class %s already has an input processor "
  75. "registered to %s. It is overwritten by the new one.",
  76. model_cls, self)
  77. self._input_processors[model_cls] = processor \
  78. or self._default_input_processor
  79. return model_cls
  80. return wrapper
  81. def process_input(
  82. self, data: D, model_config: ModelConfig,
  83. vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
  84. """
  85. Apply an input processor to a :class:`~MultiModalData` instance passed
  86. to the model.
  87. The model is identified by ``model_config``. ``vlm_config`` is
  88. for compatibility purposes and may be merged into ``model_config``
  89. in the near future.
  90. """
  91. model_cls = self.get_model_cls(model_config)
  92. processor = self._input_processors.get(model_cls)
  93. if processor is None:
  94. raise KeyError(f"No input processor in {self} is registered for "
  95. f"model class {model_cls.__name__}.")
  96. return processor(data, model_config, vlm_config)