llava.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390
  1. import itertools
  2. from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
  3. import torch
  4. import torch.nn as nn
  5. from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
  6. from aphrodite.attention import AttentionMetadata
  7. from aphrodite.common.config import CacheConfig, MultiModalConfig
  8. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  9. from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
  10. from aphrodite.modeling.layers.activation import get_act_fn
  11. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  12. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  13. from aphrodite.multimodal import MULTIMODAL_REGISTRY
  14. from aphrodite.quantization.base_config import QuantizationConfig
  15. from .clip import (CLIPVisionModel, dummy_image_for_clip,
  16. dummy_seq_data_for_clip, get_max_clip_image_tokens,
  17. input_processor_for_clip)
  18. from .interfaces import SupportsMultiModal
  19. from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
  20. dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
  21. input_processor_for_siglip)
  22. from .utils import (filter_weights, init_aphrodite_registered_model,
  23. merge_multimodal_embeddings)
  class LlavaImagePixelInputs(TypedDict):
      """Image input given as raw pixels, to be encoded by the vision tower."""

      type: Literal["pixel_values"]
      # Shape: `(batch_size, num_channels, height, width)`
      data: torch.Tensor
  class LlavaImageEmbeddingInputs(TypedDict):
      """Image input given as precomputed embeddings.

      Embeddings of this kind bypass the vision tower and the projector and
      are merged into the language model's input embeddings as-is.
      """

      type: Literal["image_embeds"]
      # Shape: `(batch_size, image_feature_size, hidden_size)`
      # `hidden_size` must match the hidden size of language model backbone.
      data: torch.Tensor
  # Either accepted form of LLaVA image input; the `type` key tells them apart.
  LlavaImageInputs = Union[
      LlavaImagePixelInputs,
      LlavaImageEmbeddingInputs,
  ]
  # TODO: Run benchmark and decide if TP.
  class LlavaMultiModalProjector(nn.Module):
      """Two-layer MLP projecting vision features into the text hidden space.

      Computes ``linear_2(act(linear_1(x)))``. ``linear_1`` maps
      ``vision_hidden_size -> text_hidden_size`` and ``linear_2`` maps
      ``text_hidden_size -> text_hidden_size``. Both layers use a bias.
      """

      def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                   projector_hidden_act: str):
          super().__init__()
          # The submodule names form the parameter names (e.g.
          # `linear_1.weight`) that checkpoint weights are looked up by, so
          # they must not change. The registration order is kept as well.
          self.linear_1 = nn.Linear(vision_hidden_size,
                                    text_hidden_size,
                                    bias=True)
          self.act = get_act_fn(projector_hidden_act)
          self.linear_2 = nn.Linear(text_hidden_size,
                                    text_hidden_size,
                                    bias=True)

      def forward(self, image_features: torch.Tensor) -> torch.Tensor:
          """Project the last dimension of ``image_features`` to text size."""
          return self.linear_2(self.act(self.linear_1(image_features)))
  def get_max_llava_image_tokens(ctx: InputContext):
      """Return how many image tokens LLaVA feeds to the language model.

      This is the vision encoder's maximum token count, reduced by one when
      ``vision_feature_select_strategy`` is ``"default"`` and left unchanged
      when it is ``"full"``.

      Raises:
          NotImplementedError: if the vision config is neither CLIP nor SigLIP.
          ValueError: if the select strategy is not "default" or "full".
      """
      hf_config = ctx.get_hf_config(LlavaConfig)
      vision_config = hf_config.vision_config

      if isinstance(vision_config, CLIPVisionConfig):
          encoder_token_count = get_max_clip_image_tokens(vision_config)
      elif isinstance(vision_config, SiglipVisionConfig):
          encoder_token_count = get_max_siglip_image_tokens(vision_config)
      else:
          raise NotImplementedError(
              f"Unsupported vision config: {type(vision_config)}")

      select_strategy = hf_config.vision_feature_select_strategy
      if select_strategy == "default":
          # Matches the "default" feature selection, which drops position 0
          # of the encoder output (`image_features[:, 1:]`).
          return encoder_token_count - 1
      if select_strategy == "full":
          return encoder_token_count
      raise ValueError(
          f"Unexpected select feature strategy: {select_strategy}")
  def dummy_data_for_llava(ctx: InputContext, seq_len: int):
      """Build dummy ``(seq_data, mm_data)`` inputs for a LLaVA model.

      The sequence helper matching the vision config (CLIP or SigLIP) is
      called with ``seq_len``. It receives ``hf_config.image_token_index`` as
      the image token and ``get_max_llava_image_tokens(ctx)`` as the image
      feature size. The dummy image comes from the matching image helper.

      Raises:
          NotImplementedError: if the vision config is neither CLIP nor SigLIP.
      """
      hf_config = ctx.get_hf_config(LlavaConfig)
      vision_config = hf_config.vision_config
      image_feature_size = get_max_llava_image_tokens(ctx)

      if isinstance(vision_config, CLIPVisionConfig):
          build_seq_data = dummy_seq_data_for_clip
          build_image = dummy_image_for_clip
      elif isinstance(vision_config, SiglipVisionConfig):
          build_seq_data = dummy_seq_data_for_siglip
          build_image = dummy_image_for_siglip
      else:
          raise NotImplementedError(
              f"Unsupported vision config: {type(vision_config)}")

      seq_data = build_seq_data(
          vision_config,
          seq_len,
          image_token_id=hf_config.image_token_index,
          image_feature_size_override=image_feature_size,
      )
      mm_data = build_image(vision_config)
      return seq_data, mm_data
  def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
      """Apply the vision-encoder-specific input processor to ``llm_inputs``.

      Inputs whose ``multi_modal_data`` is missing or has no ``"image"``
      entry are returned unchanged. Otherwise the CLIP or SigLIP input
      processor matching the vision config is called. It receives
      ``hf_config.image_token_index`` as the image token and
      ``get_max_llava_image_tokens(ctx)`` as the image feature size.

      Raises:
          NotImplementedError: if the vision config is neither CLIP nor SigLIP.
      """
      multi_modal_data = llm_inputs.get("multi_modal_data")
      has_image = multi_modal_data is not None and "image" in multi_modal_data
      if not has_image:
          return llm_inputs

      model_config = ctx.model_config
      hf_config = ctx.get_hf_config(LlavaConfig)
      vision_config = hf_config.vision_config
      image_feature_size = get_max_llava_image_tokens(ctx)

      if isinstance(vision_config, CLIPVisionConfig):
          process = input_processor_for_clip
      elif isinstance(vision_config, SiglipVisionConfig):
          process = input_processor_for_siglip
      else:
          raise NotImplementedError(
              f"Unsupported vision config: {type(vision_config)}")

      return process(
          model_config,
          vision_config,
          llm_inputs,
          image_token_id=hf_config.image_token_index,
          image_feature_size_override=image_feature_size,
      )
  def _init_vision_tower(hf_config: LlavaConfig):
      """Build the vision encoder, truncated after ``vision_feature_layer``.

      A negative ``vision_feature_layer`` counts from the end. ``-1`` keeps
      all ``num_hidden_layers`` layers, and ``-2`` drops the last one. Layers
      past the feature layer are not built, because the feature layer's
      output is taken as the tower's output.

      Raises:
          NotImplementedError: if the vision config is neither CLIP nor SigLIP.
      """
      vision_config = hf_config.vision_config
      feature_layer = hf_config.vision_feature_layer

      # Number of encoder layers needed to reach the feature layer.
      num_hidden_layers = (
          vision_config.num_hidden_layers + feature_layer + 1
          if feature_layer < 0 else feature_layer + 1)

      if isinstance(vision_config, CLIPVisionConfig):
          tower_cls = CLIPVisionModel
      elif isinstance(vision_config, SiglipVisionConfig):
          tower_cls = SiglipVisionModel
      else:
          raise NotImplementedError(
              f"Unsupported vision config: {type(vision_config)}")

      return tower_cls(
          vision_config,
          num_hidden_layers_override=num_hidden_layers,
      )
  @MULTIMODAL_REGISTRY.register_image_input_mapper()
  @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
  @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
  @INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
  class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
      """LLaVA: a vision tower and MLP projector in front of a language model.

      Image features are either computed from ``pixel_values`` by the vision
      tower and projector, or supplied directly as ``image_embeds``. They are
      merged into the token embeddings at the ``config.image_token_index``
      placeholder positions, and the language model then runs on the merged
      embeddings.
      """

      def __init__(self,
                   config: LlavaConfig,
                   multimodal_config: MultiModalConfig,
                   cache_config: Optional[CacheConfig] = None,
                   quant_config: Optional[QuantizationConfig] = None) -> None:
          super().__init__()

          self.config = config
          self.multimodal_config = multimodal_config

          # TODO: Optionally initializes this for supporting embeddings.
          self.vision_tower = _init_vision_tower(config)
          self.multi_modal_projector = LlavaMultiModalProjector(
              vision_hidden_size=config.vision_config.hidden_size,
              text_hidden_size=config.text_config.hidden_size,
              projector_hidden_act=config.projector_hidden_act)

          self.language_model = init_aphrodite_registered_model(
              config.text_config, cache_config, quant_config)

      def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
          """Return ``data`` unchanged if its shape is ``(batch, 3, S, S)``.

          ``S`` is ``config.vision_config.image_size``.

          Raises:
              ValueError: if the non-batch dimensions are not ``(3, S, S)``.
          """
          h = w = self.config.vision_config.image_size
          expected_dims = (3, h, w)
          actual_dims = tuple(data.shape[1:])

          if actual_dims != expected_dims:
              expected_expr = ("batch_size", *map(str, expected_dims))
              raise ValueError(
                  f"The expected shape of pixel values is {expected_expr}. "
                  f"You supplied {tuple(data.shape)}.")

          return data

      def _parse_and_validate_image_input(
              self, **kwargs: object) -> Optional[LlavaImageInputs]:
          """Extract the image input from ``kwargs``, if present.

          Reads ``pixel_values`` and ``image_embeds``. Returns ``None`` when
          neither is given. ``pixel_values`` takes precedence when both are.

          Raises:
              ValueError: if the selected input is not a ``torch.Tensor``, or
                  if the pixel values have an unexpected shape.
          """
          pixel_values = kwargs.pop("pixel_values", None)
          image_embeds = kwargs.pop("image_embeds", None)

          if pixel_values is None and image_embeds is None:
              return None

          if pixel_values is not None:
              if not isinstance(pixel_values, torch.Tensor):
                  raise ValueError("Incorrect type of pixel values. "
                                   f"Got type: {type(pixel_values)}")

              return LlavaImagePixelInputs(
                  type="pixel_values",
                  data=self._validate_pixel_values(pixel_values),
              )

          if image_embeds is not None:
              if not isinstance(image_embeds, torch.Tensor):
                  raise ValueError("Incorrect type of image embeddings. "
                                   f"Got type: {type(image_embeds)}")

              return LlavaImageEmbeddingInputs(
                  type="image_embeds",
                  data=image_embeds,
              )

          raise AssertionError("This line should be unreachable.")

      def _select_image_features(self, image_features: torch.Tensor, *,
                                 strategy: str) -> torch.Tensor:
          """Apply the vision feature select strategy along dim 1.

          ``"default"`` drops position 0 and ``"full"`` keeps everything.

          Raises:
              ValueError: for any other strategy.
          """
          # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
          if strategy == "default":
              return image_features[:, 1:]
          elif strategy == "full":
              return image_features

          raise ValueError(f"Unexpected select feature strategy: {strategy}")

      def _image_pixels_to_features(
          self,
          vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
          pixel_values: torch.Tensor,
      ) -> torch.Tensor:
          """Encode ``pixel_values`` and apply the configured feature selection."""
          # NOTE: we skip the step to select the vision feature layer since
          # this is already done inside the vision tower
          image_features = vision_tower(pixel_values)

          return self._select_image_features(
              image_features,
              strategy=self.config.vision_feature_select_strategy,
          )

      def _process_image_pixels(self,
                                inputs: LlavaImagePixelInputs) -> torch.Tensor:
          """Run pixel inputs through the vision tower (projector not applied)."""
          assert self.vision_tower is not None

          pixel_values = inputs["data"]

          return self._image_pixels_to_features(self.vision_tower, pixel_values)

      def _process_image_input(self,
                               image_input: LlavaImageInputs) -> torch.Tensor:
          """Return embeddings ready to be merged into the text embeddings.

          Precomputed ``image_embeds`` are returned unchanged. Pixel inputs
          are encoded and then passed through the multimodal projector.
          """
          if image_input["type"] == "image_embeds":
              return image_input["data"]

          assert self.vision_tower is not None
          image_features = self._process_image_pixels(image_input)
          return self.multi_modal_projector(image_features)

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
          intermediate_tensors: Optional[IntermediateTensors] = None,
          **kwargs: object,
      ) -> Union[torch.Tensor, IntermediateTensors]:
          """Run forward pass for LLaVA-1.5.

          One key thing to understand is the `input_ids` already accounts for
          the positions of the to-be-inserted image embeddings.

          Concretely, consider a text prompt:
          `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

          Tokenizer outputs:
          `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
          278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

          To reserve space in KV cache, we have to insert placeholder tokens
          before they are inputted to the model, so the input processor
          expands the single image token (denoted as `32000`) into a run of
          image tokens, resulting in:
          `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
          29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047,
          13566, 29901]`.

          We insert 575 tokens so that including the original image token in
          the input, there are a total of 576 (24 * 24) image tokens, which
          corresponds to the number of image tokens inputted to the language
          model, i.e. the number of image tokens outputted by the visual
          encoder.

          This way, the `positions` and `attn_metadata` are consistent
          with the `input_ids`.

          Args:
              input_ids: Flattened (concatenated) input_ids corresponding to a
                  batch.
              positions: Positions matching ``input_ids``. They are passed to
                  the language model unchanged.
              kv_caches: Passed to the language model unchanged.
              attn_metadata: Passed to the language model unchanged.
              intermediate_tensors: Passed to the language model unchanged.
                  These are presumably hidden states handed over from a
                  previous pipeline stage.
              **kwargs: Optional image input, either ``pixel_values`` or
                  ``image_embeds``, as a ``torch.Tensor``.

          Returns:
              Whatever ``language_model.model`` returns, i.e. hidden states.
              This is not sampler output; see :meth:`compute_logits` and
              :meth:`sample`.

          See also:
              :class:`LlavaImageInputs`
          """
          image_input = self._parse_and_validate_image_input(**kwargs)

          if image_input is not None:
              vision_embeddings = self._process_image_input(image_input)
              inputs_embeds = self.language_model.model.get_input_embeddings(
                  input_ids)

              inputs_embeds = merge_multimodal_embeddings(
                  input_ids, inputs_embeds, vision_embeddings,
                  self.config.image_token_index)

              # The merged embeddings replace the token ids as model input.
              input_ids = None
          else:
              inputs_embeds = None

          # Forward intermediate_tensors rather than a hard-coded None, so the
          # argument this method accepts is not silently discarded.
          hidden_states = self.language_model.model(input_ids,
                                                    positions,
                                                    kv_caches,
                                                    attn_metadata,
                                                    intermediate_tensors,
                                                    inputs_embeds=inputs_embeds)

          return hidden_states

      def compute_logits(
          self,
          hidden_states: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[torch.Tensor]:
          """Delegate logit computation to the language model."""
          return self.language_model.compute_logits(hidden_states,
                                                    sampling_metadata)

      def sample(
          self,
          logits: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[SamplerOutput]:
          """Delegate sampling to the language model."""
          return self.language_model.sample(logits, sampling_metadata)

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          """Load checkpoint weights into the vision tower, projector and LM.

          ``weights`` is split into three copies with ``itertools.tee``. Each
          copy is filtered by one component prefix: ``vision_tower``,
          ``multi_modal_projector`` or ``language_model``.
          """
          # NOTE(review): tee buffers every item one copy has consumed until
          # the other copies read it. If the vision tower loader drains its
          # copy fully, the remaining weights are held in memory meanwhile.
          vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)

          # load vision encoder
          vit_weights = filter_weights(vit_weights, "vision_tower")
          self.vision_tower.load_weights(vit_weights)

          # load mlp projector
          mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
          mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
          for name, loaded_weight in mlp_weights:
              # The lookup assumes filter_weights yields names relative to the
              # projector, i.e. with the component prefix removed.
              param = mlp_params_dict[name]
              weight_loader = getattr(param, "weight_loader",
                                      default_weight_loader)
              weight_loader(param, loaded_weight)

          # load llm backbone
          llm_weights = filter_weights(llm_weights, "language_model")
          self.language_model.load_weights(llm_weights)