registry.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. import functools
  2. from array import array
  3. from collections import UserDict
  4. from dataclasses import dataclass
  5. from typing import (TYPE_CHECKING, Callable, Dict, Mapping, Optional, Protocol,
  6. Tuple, Type)
  7. from loguru import logger
  8. from torch import nn
  9. from transformers import PretrainedConfig
  10. from typing_extensions import TypeVar
  11. from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
  12. from .data import LLMInputs
  13. if TYPE_CHECKING:
  14. from aphrodite.common.config import ModelConfig
  15. from aphrodite.common.sequence import SequenceData
  16. from aphrodite.multimodal import MultiModalDataDict, MultiModalRegistry
  17. C = TypeVar("C", bound=PretrainedConfig)
  18. @dataclass(frozen=True)
  19. class InputContext:
  20. """
  21. Contains information about the model which may be used to
  22. modify the inputs.
  23. """
  24. model_config: "ModelConfig"
  25. """The configuration of the model."""
  26. def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C:
  27. """
  28. Get the HuggingFace configuration
  29. (:class:`transformers.PretrainedConfig`) of the model,
  30. additionally checking its type.
  31. Raises:
  32. ValueError: If the model is not of the specified type.
  33. """
  34. hf_config = self.model_config.hf_config
  35. if not isinstance(hf_config, hf_config_type):
  36. raise TypeError("Invalid type of HuggingFace config. "
  37. f"Expected type: {hf_config_type}, but "
  38. f"found type: {type(hf_config)}")
  39. return hf_config
  40. N = TypeVar("N", bound=Type[nn.Module])
  41. class DummyDataFactory(Protocol):
  42. def __call__(
  43. self,
  44. ctx: InputContext,
  45. seq_len: int,
  46. mm_counts: Mapping[str, int],
  47. ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
  48. """
  49. Create dummy data to be inputted into the model.
  50. Note:
  51. :data:`InputProcessor` is not applied to the dummy data.
  52. """
  53. ...
  54. class _MultiModalCounts(UserDict):
  55. """
  56. Wraps `mm_counts` for a more informative error message
  57. when attempting to access a plugin that does not exist.
  58. """
  59. def __getitem__(self, key: str) -> int:
  60. try:
  61. return super().__getitem__(key)
  62. except KeyError as exc:
  63. msg = (f"There is no multi-modal plugin with the key: {key}. "
  64. f"Available keys: {set(self.keys())}")
  65. raise KeyError(msg) from exc
  66. InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs]
  67. """Preprocess the inputs to the model."""
  68. class InputRegistry:
  69. """
  70. A registry to dispatch data processing
  71. according to the target model.
  72. """
  73. def __init__(self) -> None:
  74. self._dummy_factories_by_model_type: Dict[Type[nn.Module],
  75. DummyDataFactory] = {}
  76. self._input_processors_by_model_type: Dict[Type[nn.Module],
  77. InputProcessor] = {}
  78. def _default_dummy_data_factory(
  79. self,
  80. ctx: InputContext,
  81. seq_len: int,
  82. mm_counts: Mapping[str, int],
  83. ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
  84. """
  85. The default dummy data factory represents the longest possible text
  86. that can be inputted to the model.
  87. Note:
  88. :data:`InputProcessor` is not applied to the dummy data.
  89. """
  90. # Avoid circular import
  91. from aphrodite.common.sequence import SequenceData
  92. dummy_seq_data = SequenceData(
  93. array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len)
  94. dummy_multi_modal_data = None
  95. return dummy_seq_data, dummy_multi_modal_data
  96. def register_dummy_data(self, factory: DummyDataFactory):
  97. """
  98. Register a dummy data factory to a model class.
  99. During memory profiling, the provided function is invoked to create
  100. dummy data to be inputted into the model. The resulting memory usage
  101. should be an upper bound of what the model would use at inference time.
  102. """
  103. def wrapper(model_cls: N) -> N:
  104. if model_cls in self._dummy_factories_by_model_type:
  105. logger.warning(
  106. f"Model class {model_cls} already has dummy data "
  107. f"registered to {self}. It is overwritten by the new one.")
  108. self._dummy_factories_by_model_type[model_cls] = factory
  109. return model_cls
  110. return wrapper
  111. def dummy_data_for_profiling(
  112. self,
  113. model_config: "ModelConfig",
  114. seq_len: int,
  115. mm_registry: "MultiModalRegistry",
  116. ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
  117. """
  118. Create dummy data for profiling the memory usage of a model.
  119. The model is identified by ``model_config``.
  120. See also:
  121. :ref:`enabling_multimodal_inputs`
  122. Note:
  123. This should be called after
  124. :meth:`~MultiModalRegistry.init_mm_limits_per_prompt`.
  125. """
  126. # Avoid circular import
  127. from aphrodite.modeling.model_loader import get_model_architecture
  128. model_cls, _ = get_model_architecture(model_config)
  129. dummy_factory = self._dummy_factories_by_model_type \
  130. .get(model_cls, self._default_dummy_data_factory)
  131. mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
  132. seq_data, mm_data = dummy_factory(
  133. InputContext(model_config),
  134. seq_len,
  135. _MultiModalCounts(mm_counts),
  136. )
  137. # Having more tokens is over-conservative but otherwise fine
  138. num_tokens = seq_data.prompt_token_ids
  139. assert len(num_tokens) >= seq_len, (
  140. f"Expected at least {seq_len} dummy tokens for profiling, "
  141. f"but found {len(num_tokens)} tokens instead.")
  142. if mm_data is not None:
  143. for k, v in mm_data.items():
  144. num_items = len(v) if isinstance(v, list) else 1
  145. num_expected = mm_counts[k]
  146. assert num_items >= num_expected, (
  147. f"Expected at least {num_expected} dummy '{k}' instances "
  148. f"for profiling, but found {num_items} instances instead.")
  149. return seq_data, mm_data
  150. def _default_input_processor(self, ctx: InputContext,
  151. inputs: LLMInputs) -> LLMInputs:
  152. """The default input processor is a no-op."""
  153. return inputs
  154. def register_input_processor(self, processor: InputProcessor):
  155. """
  156. Register an input processor to a model class.
  157. The provided function is invoked on each input to the model. This
  158. happens before
  159. :meth:`~aphrodite.multimodal.MultiModalRegistry.map_input`.
  160. See also:
  161. :ref:`input_processing_pipeline`
  162. """
  163. def wrapper(model_cls: N) -> N:
  164. if model_cls in self._input_processors_by_model_type:
  165. logger.warning(
  166. f"Model class {model_cls} already has input processor "
  167. f"registered to {self}. It is overwritten by the new one.")
  168. self._input_processors_by_model_type[model_cls] = processor
  169. return model_cls
  170. return wrapper
  171. def process_input(self, model_config: "ModelConfig",
  172. inputs: LLMInputs) -> LLMInputs:
  173. """
  174. Apply an input processor to an instance of model inputs.
  175. The model is identified by ``model_config``.
  176. See also:
  177. :ref:`input_processing_pipeline`
  178. """
  179. # Avoid circular import
  180. from aphrodite.modeling.model_loader import get_model_architecture
  181. model_cls, _ = get_model_architecture(model_config)
  182. processor = self._input_processors_by_model_type \
  183. .get(model_cls, self._default_input_processor)
  184. return processor(InputContext(model_config), inputs)
  185. def create_input_processor(self, model_config: "ModelConfig"):
  186. """
  187. Create an input processor (see :meth:`process_input`) for a
  188. specific model.
  189. """
  190. return functools.partial(self.process_input, model_config)