base.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. import sys
  2. from abc import ABC, abstractmethod
  3. from collections import UserDict, defaultdict
  4. from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type,
  5. TypedDict, TypeVar, Union, cast, final)
  6. import numpy as np
  7. import torch
  8. import torch.types
  9. from loguru import logger
  10. from PIL import Image
  11. from torch import nn
  12. from typing_extensions import TypeAlias
  13. from aphrodite.common.config import ModelConfig
  14. from aphrodite.common.utils import (get_allowed_kwarg_only_overrides,
  15. is_list_of, json_map_leaves)
  16. from aphrodite.inputs import InputContext
  17. NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]
  18. """
  19. Uses a list instead of a tensor if the dimensions of each element do not match.
  20. """
  21. BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
  22. """
  23. A dictionary containing nested tensors which have been batched via
  24. :meth:`MultiModalInputs.batch`.
  25. """
  26. if sys.version_info < (3, 9):
  27. # UserDict cannot be subscripted
  28. class _MultiModalInputsBase(UserDict):
  29. pass
  30. else:
  31. class _MultiModalInputsBase(UserDict[str, NestedTensors]):
  32. pass
  33. class MultiModalInputs(_MultiModalInputsBase):
  34. """
  35. A dictionary that represents the keyword arguments to
  36. :meth:`~torch.nn.Module.forward`.
  37. """
  38. @staticmethod
  39. def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
  40. """
  41. Recursively stacks lists of tensors when they all have the same shape.
  42. """
  43. if isinstance(nested_tensors, torch.Tensor):
  44. return nested_tensors
  45. stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
  46. if not is_list_of(stacked, torch.Tensor, check="all"):
  47. # Only tensors (not lists) can be stacked.
  48. return stacked
  49. tensors_ = cast(List[torch.Tensor], stacked)
  50. if any(t.shape != tensors_[0].shape for t in tensors_):
  51. # The tensors have incompatible shapes and can't be stacked.
  52. return tensors_
  53. return torch.stack(tensors_)
  54. @staticmethod
  55. def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs:
  56. """
  57. Batch multiple inputs together into a dictionary.
  58. The resulting dictionary has the same keys as the inputs.
  59. If the corresponding value from each input is a tensor and they all
  60. share the same shape, the output value is a single batched tensor;
  61. otherwise, the output value is a list containing the original value
  62. from each input.
  63. """
  64. if len(inputs_list) == 0:
  65. return {}
  66. item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)
  67. for inputs in inputs_list:
  68. # For models that supports multiple modalities (e.g. Qwen2-VL),
  69. # different modalities will return different data keys,
  70. # so batch() should skip the same key check.
  71. for k, v in inputs.items():
  72. item_lists[k].append(v)
  73. return {
  74. k: MultiModalInputs._try_stack(item_list)
  75. for k, item_list in item_lists.items()
  76. } # type: ignore
  77. @staticmethod
  78. def as_kwargs(
  79. batched_inputs: BatchedTensorInputs,
  80. *,
  81. device: torch.types.Device,
  82. ) -> BatchedTensorInputs:
  83. return json_map_leaves(lambda x: x.to(device, non_blocking=True),
  84. batched_inputs)
# Generic element type used to parameterize MultiModalData.
_T = TypeVar("_T")

MultiModalData: TypeAlias = Union[_T, List[_T]]
"""
Either a single data instance, or a list of data instances.

The number of data instances allowed per modality is restricted by
`--limit-mm-per-prompt`.
"""
@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Modality types that are built into Aphrodite.

    Every key is optional (``total=False``). Each value may be a single item
    or a list of items (see ``MultiModalData``).
    """

    image: MultiModalData[Image.Image]
    """The input image(s)."""

    audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]]
    """The input audio item(s) and corresponding sampling rate(s)."""
  99. MultiModalDataDict = Union[MultiModalDataBuiltins,
  100. Mapping[str, MultiModalData[object]]]
  101. """
  102. A dictionary containing an item for each modality type to input.
  103. The data belonging to each modality is converted into keyword arguments
  104. to the model by the corresponding mapper. By default, the mapper of
  105. the corresponding plugin with the same modality key is applied.
  106. """
  107. MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]],
  108. MultiModalInputs]
  109. """
  110. Return a dictionary to be passed as keyword arguments to
  111. :meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
  112. and processors in HuggingFace Transformers.
  113. If the data is not supported, throw :exc:`TypeError`.
  114. """
  115. MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
  116. """
  117. Calculate the maximum number of multimodal tokens input to the language
  118. model. This does not include tokens that correspond to the input text.
  119. """
# Type variable for model classes accepted by the MultiModalPlugin
# registration decorators, which return the same class they receive.
N = TypeVar("N", bound=Type[nn.Module])
  121. class MultiModalPlugin(ABC):
  122. """
  123. Base class that defines data processing logic for a specific modality.
  124. In particular, we adopt a registry pattern to dispatch data processing
  125. according to the model being used (considering that different models may
  126. process the same data differently). This registry is in turn used by
  127. :class:`~MultiModalRegistry` which acts at a higher level
  128. (i.e., the modality of the data).
  129. """
  130. def __init__(self) -> None:
  131. self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
  132. self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {}
  133. @abstractmethod
  134. def get_data_key(self) -> str:
  135. """
  136. Get the data key corresponding to the modality.
  137. """
  138. raise NotImplementedError
  139. @abstractmethod
  140. def _default_input_mapper(
  141. self,
  142. ctx: InputContext,
  143. data: MultiModalData[object],
  144. ) -> MultiModalInputs:
  145. """
  146. Return a dictionary to be passed as keyword arguments to
  147. :meth:`~torch.nn.Module.forward`. This is similar in concept to
  148. tokenizers and processors in HuggingFace Transformers.
  149. If the data is not supported, throw :exc:`TypeError`.
  150. """
  151. raise NotImplementedError
  152. def register_input_mapper(
  153. self,
  154. mapper: Optional[MultiModalInputMapper] = None,
  155. ):
  156. """
  157. Register an input mapper to a model class.
  158. When the model receives input data that matches the modality served by
  159. this plugin (see :meth:`get_data_type`), the provided function is
  160. invoked to transform the data into a dictionary of model inputs.
  161. If `None` is provided, then the default input mapper is used instead.
  162. See also:
  163. :ref:`input_processing_pipeline`
  164. :ref:`adding_a_new_multimodal_model`
  165. """
  166. def wrapper(model_cls: N) -> N:
  167. if model_cls in self._input_mappers:
  168. logger.warning(
  169. f"Model class {model_cls} already has an input mapper "
  170. f"registered to {self}. It is overwritten by the new one.")
  171. self._input_mappers[model_cls] = mapper \
  172. or self._default_input_mapper
  173. return model_cls
  174. return wrapper
  175. def map_input(self, model_config: ModelConfig,
  176. data: MultiModalData[object]) -> MultiModalInputs:
  177. """
  178. Apply an input mapper to a data passed
  179. to the model, transforming the data into a dictionary of model inputs.
  180. If the data is not something that the mapper expects, throws TypeError.
  181. The model is identified by ``model_config``.
  182. See also:
  183. :ref:`adding_a_new_multimodal_model`
  184. """
  185. # Avoid circular import
  186. from aphrodite.modeling.model_loader import get_model_architecture
  187. model_cls, _ = get_model_architecture(model_config)
  188. mapper = self._input_mappers.get(model_cls)
  189. # Only get processor kwargs at mapping time if we are not using the
  190. # input mapper; no overrides are used on the default here because they
  191. # should be passed to the huggingface resource at initialization time.
  192. if mapper is not None and mapper != self._default_input_mapper:
  193. mm_processor_kwargs = get_allowed_kwarg_only_overrides(
  194. mapper, overrides=model_config.mm_processor_kwargs)
  195. else:
  196. mm_processor_kwargs = {}
  197. if mapper is None:
  198. raise KeyError(f"No input mapper in {self} is registered for "
  199. f"model class {model_cls.__name__}.")
  200. return mapper(InputContext(model_config), data, **mm_processor_kwargs)
  201. @abstractmethod
  202. def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
  203. """
  204. Calculate the maximum number of tokens, corresponding to a single
  205. instance of multimodal data, that are passed to the language model.
  206. """
  207. raise NotImplementedError
  208. def _validate_max_multimodal_tokens(self, max_mm_tokens: int):
  209. if max_mm_tokens < 1:
  210. raise ValueError("You should set the number of tokens to a "
  211. f"positive integer. Found: {max_mm_tokens}")
  212. def register_max_multimodal_tokens(
  213. self,
  214. max_mm_tokens: Optional[MultiModalTokensCalc] = None,
  215. ):
  216. """
  217. Register the maximum number of tokens, corresponding to a single
  218. instance of multimodal data, that are passed to the language model
  219. for a model class.
  220. If `None` is provided, then the default calculation is used instead.
  221. See also:
  222. :ref:`adding_a_new_multimodal_model`
  223. """
  224. def wrapper(model_cls: N) -> N:
  225. if model_cls in self._max_mm_tokens:
  226. logger.warning(
  227. f"Model class {model_cls} already calculates maximum "
  228. f"number of tokens in {self}. It is overwritten by the "
  229. "new one.")
  230. if isinstance(max_mm_tokens, int):
  231. self._validate_max_multimodal_tokens(max_mm_tokens)
  232. self._max_mm_tokens[model_cls] = max_mm_tokens \
  233. or self._default_max_multimodal_tokens
  234. return model_cls
  235. return wrapper
  236. def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
  237. """
  238. Get the maximum number of multi-modal tokens
  239. for profiling the memory usage of a model.
  240. If this registry is not applicable to the model, `0` is returned.
  241. The model is identified by ``model_config``.
  242. See also:
  243. :ref:`adding_a_new_multimodal_model`
  244. """
  245. # Avoid circular import
  246. from aphrodite.modeling.model_loader import get_model_architecture
  247. model_cls, _ = get_model_architecture(model_config)
  248. if model_cls not in self._input_mappers:
  249. return 0
  250. max_mm_tokens = self._max_mm_tokens.get(model_cls)
  251. if max_mm_tokens is None:
  252. raise KeyError(f"No maximum number of multi-modal tokens is given "
  253. f"for model class {model_cls.__name__} in {self}.")
  254. if callable(max_mm_tokens):
  255. mm_processor_kwargs = get_allowed_kwarg_only_overrides(
  256. max_mm_tokens, overrides=model_config.mm_processor_kwargs)
  257. max_mm_tokens = max_mm_tokens(InputContext(model_config),
  258. **mm_processor_kwargs)
  259. self._validate_max_multimodal_tokens(max_mm_tokens)
  260. return max_mm_tokens