# registry.py
  1. import functools
  2. from typing import Dict, Optional, Sequence
  3. import torch
  4. from loguru import logger
  5. from aphrodite.common.config import ModelConfig
  6. from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
  7. MultiModalPlugin, MultiModalTokensCalc)
  8. from .image import ImagePlugin
  9. class MultiModalRegistry:
  10. """
  11. A registry to dispatch data processing
  12. according to its modality and the target model.
  13. The registry handles both external and internal data input.
  14. """
  15. DEFAULT_PLUGINS = (ImagePlugin(), )
  16. def __init__(
  17. self,
  18. *,
  19. plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
  20. self._plugins = {p.get_data_key(): p for p in plugins}
  21. def register_plugin(self, plugin: MultiModalPlugin) -> None:
  22. data_type_key = plugin.get_data_key()
  23. if data_type_key in self._plugins:
  24. logger.warning(
  25. "A plugin is already registered for data type "
  26. f"{data_type_key}, "
  27. f"and will be overwritten by the new plugin {plugin}.")
  28. self._plugins[data_type_key] = plugin
  29. def _get_plugin(self, data_type_key: str):
  30. plugin = self._plugins.get(data_type_key)
  31. if plugin is not None:
  32. return plugin
  33. msg = f"Unknown multi-modal data type: {data_type_key}"
  34. raise NotImplementedError(msg)
  35. def register_input_mapper(
  36. self,
  37. data_type_key: str,
  38. mapper: Optional[MultiModalInputMapper] = None,
  39. ):
  40. """
  41. Register an input mapper for a specific modality to a model class.
  42. See :meth:`MultiModalPlugin.register_input_mapper` for more details.
  43. """
  44. return self._get_plugin(data_type_key).register_input_mapper(mapper)
  45. def register_image_input_mapper(
  46. self,
  47. mapper: Optional[MultiModalInputMapper] = None,
  48. ):
  49. """
  50. Register an input mapper for image data to a model class.
  51. See :meth:`MultiModalPlugin.register_input_mapper` for more details.
  52. """
  53. return self.register_input_mapper("image", mapper)
  54. def map_input(self, model_config: ModelConfig,
  55. data: MultiModalDataDict) -> MultiModalInputs:
  56. """
  57. Apply an input mapper to the data passed to the model.
  58. See :meth:`MultiModalPlugin.map_input` for more details.
  59. """
  60. merged_dict: Dict[str, torch.Tensor] = {}
  61. for data_key, data_value in data.items():
  62. input_dict = self._get_plugin(data_key) \
  63. .map_input(model_config, data_value)
  64. for input_key, input_tensor in input_dict.items():
  65. if input_key in merged_dict:
  66. raise ValueError(f"The input mappers (keys={set(data)}) "
  67. f"resulted in a conflicting keyword "
  68. f"argument to `forward()`: {input_key}")
  69. merged_dict[input_key] = input_tensor
  70. return MultiModalInputs(merged_dict)
  71. def create_input_mapper(self, model_config: ModelConfig):
  72. """
  73. Create an input mapper (see :meth:`map_input`) for a specific model.
  74. """
  75. return functools.partial(self.map_input, model_config)
  76. def register_max_multimodal_tokens(
  77. self,
  78. data_type_key: str,
  79. max_mm_tokens: Optional[MultiModalTokensCalc] = None,
  80. ):
  81. """
  82. Register the maximum number of tokens, belonging to a
  83. specific modality, input to the language model for a model class.
  84. """
  85. return self._get_plugin(data_type_key) \
  86. .register_max_multimodal_tokens(max_mm_tokens)
  87. def register_max_image_tokens(
  88. self,
  89. max_mm_tokens: Optional[MultiModalTokensCalc] = None,
  90. ):
  91. """
  92. Register the maximum number of image tokens
  93. input to the language model for a model class.
  94. """
  95. return self.register_max_multimodal_tokens("image", max_mm_tokens)
  96. def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
  97. """
  98. Get the maximum number of multi-modal tokens
  99. for profiling the memory usage of a model.
  100. See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.
  101. """
  102. return sum(
  103. plugin.get_max_multimodal_tokens(model_config)
  104. for plugin in self._plugins.values())