# registry.py
import functools
from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence,
                    Tuple, Type, TypeVar)

from loguru import logger

from aphrodite.common.config import ModelConfig, VisionLanguageConfig

from .base import MultiModalData, MultiModalPlugin
from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData,
                    ImagePixelPlugin)
  9. if TYPE_CHECKING:
  10. import torch
  11. from torch import nn
  12. from aphrodite.common.sequence import SequenceData
  13. D = TypeVar("D", bound=MultiModalData)
  14. N = TypeVar("N", bound=Type["nn.Module"])
  15. MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig],
  16. Dict[str, "torch.Tensor"]]
  17. MultiModalDummyFactory = Callable[[int, ModelConfig, VisionLanguageConfig],
  18. Tuple["SequenceData", MultiModalData]]
  19. class MultiModalRegistry:
  20. """
  21. This registry is used by model runners to dispatch data processing
  22. according to its modality and the target model.
  23. """
  24. DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin())
  25. def __init__(self,
  26. *,
  27. plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS
  28. ) -> None:
  29. self._plugins_by_data_type = {p.get_data_type(): p for p in plugins}
  30. self._dummy_factories_by_model_type: Dict[Type["nn.Module"],
  31. MultiModalDummyFactory] = {}
  32. def register_plugin(self, plugin: MultiModalPlugin[Any]) -> None:
  33. data_type = plugin.get_data_type()
  34. if data_type in self._plugins_by_data_type:
  35. logger.warning(
  36. "A plugin is already registered for data type %s, "
  37. "and will be overwritten by the new plugin %s.", data_type,
  38. plugin)
  39. self._plugins_by_data_type[data_type] = plugin
  40. def _get_plugin_for_data_type(self, data_type: Type[MultiModalData]):
  41. for typ in data_type.mro():
  42. plugin = self._plugins_by_data_type.get(typ)
  43. if plugin is not None:
  44. return plugin
  45. msg = f"Unknown multi-modal data type: {data_type}"
  46. raise NotImplementedError(msg)
  47. def register_dummy_data(self, factory: MultiModalDummyFactory):
  48. """
  49. Register a dummy data factory to a model class.
  50. During memory profiling, the provided function is invoked to create
  51. dummy data to be inputted into the model. The modality and shape of
  52. the dummy data should be an upper bound of what the model would receive
  53. at inference time.
  54. """
  55. def wrapper(model_cls: N) -> N:
  56. if model_cls in self._dummy_factories_by_model_type:
  57. logger.warning(
  58. "Model class %s already has dummy data "
  59. "registered to %s. It is overwritten by the new one.",
  60. model_cls, self)
  61. self._dummy_factories_by_model_type[model_cls] = factory
  62. return model_cls
  63. return wrapper
  64. def dummy_data_for_profiling(self, seq_len: int, model_config: ModelConfig,
  65. vlm_config: VisionLanguageConfig):
  66. """Create dummy data for memory profiling."""
  67. model_cls = MultiModalPlugin.get_model_cls(model_config)
  68. dummy_factory = self._dummy_factories_by_model_type.get(model_cls)
  69. if dummy_factory is None:
  70. msg = f"No dummy data defined for model class: {model_cls}"
  71. raise NotImplementedError(msg)
  72. return dummy_factory(seq_len, model_config, vlm_config)
  73. def register_input(
  74. self,
  75. data_type: Type[D],
  76. processor: Optional[MultiModalInputProcessor[D]] = None):
  77. """
  78. Register an input processor for a specific modality to a model class.
  79. See :meth:`MultiModalPlugin.register_input_processor` for more details.
  80. """
  81. return self._get_plugin_for_data_type(data_type) \
  82. .register_input_processor(processor)
  83. def register_image_pixel_input(
  84. self,
  85. processor: Optional[
  86. MultiModalInputProcessor[ImagePixelData]] = None):
  87. """
  88. Register an input processor for image pixel data to a model class.
  89. See :meth:`MultiModalPlugin.register_input_processor` for more details.
  90. """
  91. return self.register_input(ImagePixelData, processor)
  92. def register_image_feature_input(
  93. self,
  94. processor: Optional[
  95. MultiModalInputProcessor[ImageFeatureData]] = None):
  96. """
  97. Register an input processor for image feature data to a model class.
  98. See :meth:`MultiModalPlugin.register_input_processor` for more details.
  99. """
  100. return self.register_input(ImageFeatureData, processor)
  101. def process_input(self, data: MultiModalData, model_config: ModelConfig,
  102. vlm_config: VisionLanguageConfig):
  103. """
  104. Apply an input processor to a :class:`~MultiModalData` instance passed
  105. to the model.
  106. See :meth:`MultiModalPlugin.process_input` for more details.
  107. """
  108. return self._get_plugin_for_data_type(type(data)) \
  109. .process_input(data, model_config, vlm_config)
  110. def create_input_processor(self, model_config: ModelConfig,
  111. vlm_config: VisionLanguageConfig):
  112. """
  113. Create an input processor (see :meth:`process_input`) for a
  114. specific model.
  115. """
  116. return functools.partial(self.process_input,
  117. model_config=model_config,
  118. vlm_config=vlm_config)
  119. MULTIMODAL_REGISTRY = MultiModalRegistry()
  120. """The global :class:`~MultiModalRegistry` which is used by model runners."""