# registry.py
  1. import functools
  2. from collections import UserDict
  3. from typing import Dict, Mapping, Optional, Sequence
  4. from loguru import logger
  5. from aphrodite.common.config import ModelConfig
  6. from .audio import AudioPlugin
  7. from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
  8. MultiModalPlugin, MultiModalTokensCalc, NestedTensors)
  9. from .image import ImagePlugin
  10. from .video import VideoPlugin
  11. class _MultiModalLimits(UserDict):
  12. """
  13. Wraps `_limits_by_model` for a more informative error message
  14. when attempting to access a model that does not exist.
  15. """
  16. def __getitem__(self, key: ModelConfig) -> Dict[str, int]:
  17. try:
  18. return super().__getitem__(key)
  19. except KeyError as exc:
  20. msg = (f"Cannot find `mm_limits` for model={key.model}. Did you "
  21. "forget to call `init_mm_limits_per_prompt`?")
  22. raise KeyError(msg) from exc
  23. class MultiModalRegistry:
  24. """
  25. A registry to dispatch data processing
  26. according to its modality and the target model.
  27. The registry handles both external and internal data input.
  28. """
  29. DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin())
  30. def __init__(
  31. self,
  32. *,
  33. plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
  34. self._plugins = {p.get_data_key(): p for p in plugins}
  35. # This is used for non-multimodal models
  36. self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}
  37. self._limits_by_model = _MultiModalLimits()
  38. def register_plugin(self, plugin: MultiModalPlugin) -> None:
  39. data_type_key = plugin.get_data_key()
  40. if data_type_key in self._plugins:
  41. logger.warning(
  42. "A plugin is already registered for data type "
  43. f"{data_type_key}, "
  44. f"and will be overwritten by the new plugin {plugin}.")
  45. self._plugins[data_type_key] = plugin
  46. def _get_plugin(self, data_type_key: str):
  47. plugin = self._plugins.get(data_type_key)
  48. if plugin is not None:
  49. return plugin
  50. msg = f"Unknown multi-modal data type: {data_type_key}"
  51. raise NotImplementedError(msg)
  52. def register_input_mapper(
  53. self,
  54. data_type_key: str,
  55. mapper: Optional[MultiModalInputMapper] = None,
  56. ):
  57. """
  58. Register an input mapper for a specific modality to a model class.
  59. See :meth:`MultiModalPlugin.register_input_mapper` for more details.
  60. """
  61. return self._get_plugin(data_type_key).register_input_mapper(mapper)
  62. def register_image_input_mapper(
  63. self,
  64. mapper: Optional[MultiModalInputMapper] = None,
  65. ):
  66. """
  67. Register an input mapper for image data to a model class.
  68. See :meth:`MultiModalPlugin.register_input_mapper` for more details.
  69. """
  70. return self.register_input_mapper("image", mapper)
  71. def map_input(self, model_config: ModelConfig,
  72. data: MultiModalDataDict) -> MultiModalInputs:
  73. """
  74. Apply an input mapper to the data passed to the model.
  75. The data belonging to each modality is passed to the corresponding
  76. plugin which in turn converts the data into into keyword arguments
  77. via the input mapper registered for that model.
  78. See :meth:`MultiModalPlugin.map_input` for more details.
  79. Note:
  80. This should be called after :meth:`init_mm_limits_per_prompt`.
  81. """
  82. merged_dict: Dict[str, NestedTensors] = {}
  83. for data_key, data_value in data.items():
  84. plugin = self._get_plugin(data_key)
  85. num_items = len(data_value) if isinstance(data_value, list) else 1
  86. max_items = self._limits_by_model[model_config][data_key]
  87. if num_items > max_items:
  88. raise ValueError(
  89. f"You set {data_key}={max_items} (or defaulted to 1) in "
  90. f"`--limit-mm-per-prompt`, but found {num_items} items "
  91. "in the same prompt.")
  92. input_dict = plugin.map_input(model_config, data_value)
  93. for input_key, input_tensor in input_dict.items():
  94. if input_key in merged_dict:
  95. raise ValueError(f"The input mappers (keys={set(data)}) "
  96. f"resulted in a conflicting keyword "
  97. f"argument to `forward()`: {input_key}")
  98. merged_dict[input_key] = input_tensor
  99. return MultiModalInputs(merged_dict)
  100. def create_input_mapper(self, model_config: ModelConfig):
  101. """
  102. Create an input mapper (see :meth:`map_input`) for a specific model.
  103. """
  104. return functools.partial(self.map_input, model_config)
  105. def register_max_multimodal_tokens(
  106. self,
  107. data_type_key: str,
  108. max_mm_tokens: Optional[MultiModalTokensCalc] = None,
  109. ):
  110. """
  111. Register the maximum number of tokens, corresponding to a single
  112. instance of multimodal data belonging to a specific modality, that are
  113. passed to the language model for a model class.
  114. """
  115. return self._get_plugin(data_type_key) \
  116. .register_max_multimodal_tokens(max_mm_tokens)
  117. def register_max_image_tokens(
  118. self,
  119. max_mm_tokens: Optional[MultiModalTokensCalc] = None,
  120. ):
  121. """
  122. Register the maximum number of image tokens, corresponding to a single
  123. image, that are passed to the language model for a model class.
  124. """
  125. return self.register_max_multimodal_tokens("image", max_mm_tokens)
  126. def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
  127. """
  128. Get the maximum number of multi-modal tokens
  129. for profiling the memory usage of a model.
  130. See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.
  131. Note:
  132. This should be called after :meth:`init_mm_limits_per_prompt`.
  133. """
  134. limits_per_plugin = self._limits_by_model[model_config]
  135. return sum((limits_per_plugin[key] *
  136. plugin.get_max_multimodal_tokens(model_config))
  137. for key, plugin in self._plugins.items())
  138. def init_mm_limits_per_prompt(
  139. self,
  140. model_config: ModelConfig,
  141. ) -> None:
  142. """
  143. Initialize the maximum number of multi-modal input instances for each
  144. modality that are allowed per prompt for a model class.
  145. """
  146. if model_config in self._limits_by_model:
  147. logger.warning(
  148. f"`mm_limits` has already been set for model="
  149. f"{model_config.model}, and will be overwritten by the "
  150. "new values.")
  151. multimodal_config = model_config.multimodal_config
  152. if multimodal_config is None:
  153. limits_per_plugin = self._disabled_limits_per_plugin
  154. else:
  155. config_limits_per_plugin = multimodal_config.limit_per_prompt
  156. extra_keys = config_limits_per_plugin.keys() - self._plugins.keys()
  157. if extra_keys:
  158. logger.warning(
  159. "Detected extra keys in `--limit-mm-per-prompt` which "
  160. f"are not registered as multi-modal plugins: {extra_keys}."
  161. " They will be ignored.")
  162. # NOTE: Currently the default is set to 1 for each plugin
  163. # TODO: Automatically determine the limits based on budget
  164. # once more models support multi-image inputs
  165. limits_per_plugin = {
  166. key: config_limits_per_plugin.get(key, 1)
  167. for key in self._plugins
  168. }
  169. self._limits_by_model[model_config] = limits_per_plugin
  170. def get_mm_limits_per_prompt(
  171. self,
  172. model_config: ModelConfig,
  173. ) -> Mapping[str, int]:
  174. """
  175. Get the maximum number of multi-modal input instances for each modality
  176. that are allowed per prompt for a model class.
  177. Note:
  178. This should be called after :meth:`init_mm_limits_per_prompt`.
  179. """
  180. return self._limits_by_model[model_config]