1
0

registry.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. import functools
  2. from dataclasses import dataclass
  3. from typing import (TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type,
  4. TypeVar)
  5. from loguru import logger
  6. from torch import nn
  7. from transformers import PretrainedConfig
  8. from .data import LLMInputs
  9. if TYPE_CHECKING:
  10. from aphrodite.common.config import ModelConfig, MultiModalConfig
  11. from aphrodite.common.sequence import SequenceData
  12. from aphrodite.multimodal import MultiModalDataDict
  13. C = TypeVar("C", bound=PretrainedConfig)
  14. @dataclass(frozen=True)
  15. class InputContext:
  16. """
  17. Contains information about the model which may be used to
  18. modify the inputs.
  19. """
  20. model_config: "ModelConfig"
  21. """The configuration of the model."""
  22. def get_multimodal_config(self) -> "MultiModalConfig":
  23. """
  24. Get the multimodal configuration of the model.
  25. Raises:
  26. ValueError: If the model is not multimodal.
  27. """
  28. multimodal_config = self.model_config.multimodal_config
  29. if multimodal_config is None:
  30. raise ValueError("No multimodal config found")
  31. return multimodal_config
  32. def get_hf_config(self, hf_config_type: Type[C]) -> C:
  33. """
  34. Get the HuggingFace configuration
  35. (:class:`transformers.PretrainedConfig`) of the model,
  36. additionally checking its type.
  37. Raises:
  38. ValueError: If the model is not of the specified type.
  39. """
  40. hf_config = self.model_config.hf_config
  41. if not isinstance(hf_config, hf_config_type):
  42. raise TypeError("Invalid type of HuggingFace config. "
  43. f"Expected type: {hf_config_type}, but "
  44. f"found type: {type(hf_config)}")
  45. return hf_config
  46. N = TypeVar("N", bound=Type[nn.Module])
  47. DummyDataFactory = Callable[[InputContext, int],
  48. Tuple["SequenceData",
  49. Optional["MultiModalDataDict"]]]
  50. """
  51. Create dummy data to be inputted into the model.
  52. Note:
  53. :data:`InputProcessor` is not applied to the dummy data.
  54. """
  55. InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs]
  56. """Preprocess the inputs to the model."""
  57. class InputRegistry:
  58. """
  59. A registry to dispatch data processing
  60. according to the target model.
  61. """
  62. def __init__(self) -> None:
  63. self._dummy_factories_by_model_type: Dict[Type[nn.Module],
  64. DummyDataFactory] = {}
  65. self._input_processors_by_model_type: Dict[Type[nn.Module],
  66. InputProcessor] = {}
  67. def _default_dummy_data_factory(
  68. self,
  69. ctx: InputContext,
  70. seq_len: int,
  71. ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
  72. """
  73. The default dummy data factory represents the longest possible text
  74. that can be inputted to the model.
  75. Note:
  76. :data:`InputProcessor` is not applied to the dummy data.
  77. """
  78. # Avoid circular import
  79. from aphrodite.common.sequence import SequenceData
  80. dummy_seq_data = SequenceData([0] * seq_len)
  81. dummy_multi_modal_data = None
  82. return dummy_seq_data, dummy_multi_modal_data
  83. def register_dummy_data(self, factory: DummyDataFactory):
  84. """
  85. Register a dummy data factory to a model class.
  86. During memory profiling, the provided function is invoked to create
  87. dummy data to be inputted into the model. The resulting memory usage
  88. should be an upper bound of what the model would use at inference time.
  89. """
  90. def wrapper(model_cls: N) -> N:
  91. if model_cls in self._dummy_factories_by_model_type:
  92. logger.warning(
  93. f"Model class {model_cls} already has dummy data "
  94. f"registered to {self}. It is overwritten by the new one.")
  95. self._dummy_factories_by_model_type[model_cls] = factory
  96. return model_cls
  97. return wrapper
  98. def dummy_data_for_profiling(self, model_config: "ModelConfig",
  99. seq_len: int):
  100. """
  101. Create dummy data for profiling the memory usage of a model.
  102. The model is identified by ``model_config``.
  103. TODO: Add guide.
  104. """
  105. # Avoid circular import
  106. from aphrodite.modeling.model_loader import get_model_architecture
  107. model_cls, _ = get_model_architecture(model_config)
  108. dummy_factory = self._dummy_factories_by_model_type \
  109. .get(model_cls, self._default_dummy_data_factory)
  110. return dummy_factory(InputContext(model_config), seq_len)
  111. def _default_input_processor(self, ctx: InputContext,
  112. inputs: LLMInputs) -> LLMInputs:
  113. """The default input processor is a no-op."""
  114. return inputs
  115. def register_input_processor(self, processor: InputProcessor):
  116. """
  117. Register an input processor to a model class.
  118. The provided function is invoked on each input to the model. This
  119. happens before
  120. :meth:`~aphrodite.multimodal.MultiModalRegistry.map_input`.
  121. See also:
  122. :ref:`input_processing_pipeline`
  123. """
  124. def wrapper(model_cls: N) -> N:
  125. if model_cls in self._input_processors_by_model_type:
  126. logger.warning(
  127. f"Model class {model_cls} already has input processor "
  128. f"registered to {self}. It is overwritten by the new one.")
  129. self._input_processors_by_model_type[model_cls] = processor
  130. return model_cls
  131. return wrapper
  132. def process_input(self, model_config: "ModelConfig",
  133. inputs: LLMInputs) -> LLMInputs:
  134. """
  135. Apply an input processor to an instance of model inputs.
  136. The model is identified by ``model_config``.
  137. See also:
  138. :ref:`input_processing_pipeline`
  139. """
  140. # Avoid circular import
  141. from aphrodite.modeling.model_loader import get_model_architecture
  142. model_cls, _ = get_model_architecture(model_config)
  143. processor = self._input_processors_by_model_type \
  144. .get(model_cls, self._default_input_processor)
  145. return processor(InputContext(model_config), inputs)
  146. def create_input_processor(self, model_config: "ModelConfig"):
  147. """
  148. Create an input processor (see :meth:`process_input`) for a
  149. specific model.
  150. """
  151. return functools.partial(self.process_input, model_config)