llava.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
  2. TypedDict, Union)
  3. import torch
  4. import torch.nn as nn
  5. from PIL import Image
  6. from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
  7. from aphrodite.attention import AttentionMetadata
  8. from aphrodite.common.config import CacheConfig, MultiModalConfig
  9. from aphrodite.common.sequence import IntermediateTensors
  10. from aphrodite.common.utils import is_list_of
  11. from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
  12. from aphrodite.modeling.layers.activation import get_act_fn
  13. from aphrodite.modeling.layers.sampler import SamplerOutput
  14. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  15. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  16. from aphrodite.multimodal import MULTIMODAL_REGISTRY
  17. from aphrodite.quantization.base_config import QuantizationConfig
  18. from .clip import (CLIPVisionModel, dummy_image_for_clip,
  19. dummy_seq_data_for_clip, get_max_clip_image_tokens,
  20. input_processor_for_clip)
  21. from .interfaces import SupportsMultiModal
  22. from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
  23. dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
  24. input_processor_for_siglip)
  25. from .utils import (flatten_bn, group_weights_with_prefix,
  26. init_aphrodite_registered_model,
  27. merge_multimodal_embeddings)
  class LlavaImagePixelInputs(TypedDict):
      """Raw image pixels that still have to go through the vision tower."""

      # Discriminator tag for the `LlavaImageInputs` union.
      type: Literal["pixel_values"]
      # Shape: `(batch_size * num_images, num_channels, height, width)`
      data: torch.Tensor


  class LlavaImageEmbeddingInputs(TypedDict):
      """Precomputed image embeddings, used without the tower or projector."""

      # Discriminator tag for the `LlavaImageInputs` union.
      type: Literal["image_embeds"]
      # Shape: `(batch_size * num_images, image_feature_size, hidden_size)`.
      # `hidden_size` must match the hidden size of the language model
      # backbone, since these embeddings are merged into its input directly.
      data: torch.Tensor


  # Either form of image input accepted by the LLaVA model.
  LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
  # TODO: Benchmark and decide whether this projector should use
  # tensor-parallel (TP) linear layers instead of plain `nn.Linear`.
  class LlavaMultiModalProjector(nn.Module):
      """Two-layer MLP mapping vision features into the text hidden space.

      Computes ``linear_2(act(linear_1(x)))``. ``linear_1`` maps
      ``vision_hidden_size -> text_hidden_size`` and ``linear_2`` maps
      ``text_hidden_size -> text_hidden_size``; both carry a bias.
      """

      def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                   projector_hidden_act: str):
          super().__init__()
          # Attribute names are part of the weight-loading contract: the
          # model's `load_weights` looks parameters up by name, so keep
          # `linear_1` / `act` / `linear_2` stable.
          self.linear_1 = nn.Linear(in_features=vision_hidden_size,
                                    out_features=text_hidden_size,
                                    bias=True)
          self.act = get_act_fn(projector_hidden_act)
          self.linear_2 = nn.Linear(in_features=text_hidden_size,
                                    out_features=text_hidden_size,
                                    bias=True)

      def forward(self, image_features: torch.Tensor) -> torch.Tensor:
          """Project ``image_features`` along their last dimension."""
          return self.linear_2(self.act(self.linear_1(image_features)))
  def get_max_llava_image_tokens(ctx: InputContext):
      """Return how many image tokens the language model sees per image.

      The base count comes from the configured vision encoder (CLIP or
      SigLIP) and is then adjusted for ``vision_feature_select_strategy``:
      ``"default"`` drops one token and ``"full"`` keeps all of them.

      Raises:
          NotImplementedError: if the vision config is neither CLIP nor
              SigLIP.
          ValueError: if the feature-select strategy is not recognised.
      """
      hf_config = ctx.get_hf_config(LlavaConfig)
      vision_cfg = hf_config.vision_config

      if isinstance(vision_cfg, CLIPVisionConfig):
          encoder_tokens = get_max_clip_image_tokens(vision_cfg)
      elif isinstance(vision_cfg, SiglipVisionConfig):
          encoder_tokens = get_max_siglip_image_tokens(vision_cfg)
      else:
          raise NotImplementedError(
              f"Unsupported vision config: {type(vision_cfg)}")

      select_strategy = hf_config.vision_feature_select_strategy
      if select_strategy == "full":
          return encoder_tokens
      if select_strategy == "default":
          # Mirrors `_select_image_features`, which slices off the first
          # feature token under this strategy.
          return encoder_tokens - 1
      raise ValueError(
          f"Unexpected select feature strategy: {select_strategy}")
  def dummy_data_for_llava(ctx: InputContext, seq_len: int,
                           mm_counts: Mapping[str, int]):
      """Build placeholder sequence data and dummy images for LLaVA.

      The sequence reserves ``get_max_llava_image_tokens(ctx)`` image-token
      slots for each of the ``mm_counts["image"]`` images, using the
      config's ``image_token_index``. Construction is delegated to the CLIP or
      SigLIP helpers, depending on the vision config.

      Returns:
          ``(seq_data, mm_data)`` as produced by those helpers.

      Raises:
          NotImplementedError: if the vision config is neither CLIP nor
              SigLIP.
      """
      hf_config = ctx.get_hf_config(LlavaConfig)
      vision_cfg = hf_config.vision_config
      image_count = mm_counts["image"]
      feature_size = get_max_llava_image_tokens(ctx)

      if isinstance(vision_cfg, CLIPVisionConfig):
          make_seq, make_images = (dummy_seq_data_for_clip,
                                   dummy_image_for_clip)
      elif isinstance(vision_cfg, SiglipVisionConfig):
          make_seq, make_images = (dummy_seq_data_for_siglip,
                                   dummy_image_for_siglip)
      else:
          raise NotImplementedError(
              f"Unsupported vision config: {type(vision_cfg)}")

      seq_data = make_seq(
          vision_cfg,
          seq_len,
          image_count,
          image_token_id=hf_config.image_token_index,
          image_feature_size_override=feature_size,
      )
      mm_data = make_images(vision_cfg, image_count)
      return seq_data, mm_data
  def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
      """Prepare ``llm_inputs`` for LLaVA by sizing its image placeholders.

      If there is no ``"image"`` entry in ``multi_modal_data``,
      ``llm_inputs`` is returned unchanged. Otherwise the per-image feature
      size is determined from the image data:

      * a PIL image, or a list of them: the model's maximum image token count;
      * a tensor: its second dimension (a 3-D tensor is required);
      * a list of tensors: each item's second dimension.

      The work is then delegated to the CLIP or SigLIP input processor, with
      the config's ``image_token_index``.

      Raises:
          TypeError: if the image data is of an unsupported type.
          NotImplementedError: if the vision config is neither CLIP nor
              SigLIP.
      """
      mm_data = llm_inputs.get("multi_modal_data")
      if mm_data is None or "image" not in mm_data:
          return llm_inputs

      model_config = ctx.model_config
      hf_config = ctx.get_hf_config(LlavaConfig)
      vision_cfg = hf_config.vision_config
      images = mm_data["image"]

      if isinstance(images, Image.Image):
          feature_size = get_max_llava_image_tokens(ctx)
      elif is_list_of(images, Image.Image):
          feature_size = [get_max_llava_image_tokens(ctx)] * len(images)
      elif isinstance(images, torch.Tensor):
          # (num_images, image_feature_size, hidden_size); the unpacking
          # also rejects tensors that are not 3-D.
          _, feature_size, _ = images.shape
      elif is_list_of(images, torch.Tensor):
          feature_size = [item.shape[1] for item in images]
      else:
          raise TypeError(f"Invalid image type: {type(images)}")

      if isinstance(vision_cfg, CLIPVisionConfig):
          processor = input_processor_for_clip
      elif isinstance(vision_cfg, SiglipVisionConfig):
          processor = input_processor_for_siglip
      else:
          raise NotImplementedError(
              f"Unsupported vision config: {type(vision_cfg)}")

      return processor(
          model_config,
          vision_cfg,
          llm_inputs,
          image_token_id=hf_config.image_token_index,
          image_feature_size_override=feature_size,
      )
  def _init_vision_tower(hf_config: LlavaConfig):
      """Build the CLIP/SigLIP vision encoder, truncated at the feature layer.

      ``hf_config.vision_feature_layer`` indexes the encoder's hidden-state
      sequence in the Hugging Face convention. Entry 0 is the embedding
      output fed to the first encoder layer, and entry ``k`` is the output
      after ``k`` layers. A negative index ``-i`` therefore needs
      ``num_hidden_layers + 1 - i`` layers. Only the layers required to
      produce the selected entry are instantiated.

      Raises:
          NotImplementedError: if the vision config is neither CLIP nor
              SigLIP.
          ValueError: if ``vision_feature_layer`` is out of range for the
              encoder.
      """
      vision_config = hf_config.vision_config

      # Resolve the model class first so unsupported configs fail with the
      # same NotImplementedError as before, before any layer arithmetic.
      if isinstance(vision_config, CLIPVisionConfig):
          model_cls = CLIPVisionModel
      elif isinstance(vision_config, SiglipVisionConfig):
          model_cls = SiglipVisionModel
      else:
          msg = f"Unsupported vision config: {type(vision_config)}"
          raise NotImplementedError(msg)

      vision_feature_layer = hf_config.vision_feature_layer
      total_layers = vision_config.num_hidden_layers

      if vision_feature_layer < 0:
          # -1 -> all layers, -2 -> all but the last, ...
          num_hidden_layers = total_layers + vision_feature_layer + 1
      else:
          # hidden_states[k] is produced by exactly k layers (entry 0 is the
          # embedding output). Adding 1 here would build and run one layer
          # too many and return features from the wrong depth.
          num_hidden_layers = vision_feature_layer

      # Without this check, an out-of-range index would silently produce a
      # truncated (or over-long) encoder.
      if not 0 <= num_hidden_layers <= total_layers:
          raise ValueError(
              f"vision_feature_layer={vision_feature_layer} is out of range "
              f"for a vision encoder with {total_layers} hidden layers.")

      return model_cls(
          vision_config,
          num_hidden_layers_override=num_hidden_layers,
      )
  @MULTIMODAL_REGISTRY.register_image_input_mapper()
  @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
  @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
  @INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
  class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
      """LLaVA: a vision tower and MLP projector in front of a language model.

      Pixel inputs are encoded by ``vision_tower`` and mapped into the
      language model's hidden space by ``multi_modal_projector``.
      Precomputed image embeddings are used as-is. The resulting image
      embeddings are merged into the text embeddings at the positions of
      ``config.image_token_index`` before the language model backbone runs.
      """

      def __init__(self,
                   config: LlavaConfig,
                   multimodal_config: MultiModalConfig,
                   cache_config: Optional[CacheConfig] = None,
                   quant_config: Optional[QuantizationConfig] = None) -> None:
          super().__init__()

          self.config = config
          self.multimodal_config = multimodal_config

          # TODO: Optionally skip initializing the vision tower when only
          # precomputed image embeddings need to be supported.
          self.vision_tower = _init_vision_tower(config)
          self.multi_modal_projector = LlavaMultiModalProjector(
              vision_hidden_size=config.vision_config.hidden_size,
              text_hidden_size=config.text_config.hidden_size,
              projector_hidden_act=config.projector_hidden_act)

          self.language_model = init_aphrodite_registered_model(
              config.text_config, cache_config, quant_config)

      def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
          """Return ``data`` unchanged after checking its per-image shape.

          Each image must be ``(3, image_size, image_size)``, where
          ``image_size`` comes from the vision config.

          Raises:
              ValueError: if ``data.shape[1:]`` does not match.
          """
          h = w = self.config.vision_config.image_size
          expected_dims = (3, h, w)
          actual_dims = tuple(data.shape[1:])

          if actual_dims != expected_dims:
              expected_expr = ("batch_size", *map(str, expected_dims))
              raise ValueError(
                  f"The expected shape of pixel values is {expected_expr}. "
                  f"You supplied {tuple(data.shape)}.")

          return data

      def _parse_and_validate_image_input(
              self, **kwargs: object) -> Optional[LlavaImageInputs]:
          """Extract image inputs from the ``forward`` keyword arguments.

          ``pixel_values`` takes precedence over ``image_embeds`` when both
          are present. Either may be a tensor or a list; it is flattened
          with ``flatten_bn(..., concat=True)``.

          Returns:
              The parsed input, or ``None`` if no image data was supplied.

          Raises:
              ValueError: on an unsupported container type, or on pixel
                  values of the wrong shape.
          """
          pixel_values = kwargs.pop("pixel_values", None)
          image_embeds = kwargs.pop("image_embeds", None)

          if pixel_values is None and image_embeds is None:
              return None

          if pixel_values is not None:
              if not isinstance(pixel_values, (torch.Tensor, list)):
                  raise ValueError("Incorrect type of pixel values. "
                                   f"Got type: {type(pixel_values)}")

              return LlavaImagePixelInputs(
                  type="pixel_values",
                  data=self._validate_pixel_values(
                      flatten_bn(pixel_values, concat=True)),
              )

          if image_embeds is not None:
              if not isinstance(image_embeds, (torch.Tensor, list)):
                  raise ValueError("Incorrect type of image embeddings. "
                                   f"Got type: {type(image_embeds)}")

              return LlavaImageEmbeddingInputs(
                  type="image_embeds",
                  data=flatten_bn(image_embeds, concat=True),
              )

          raise AssertionError("This line should be unreachable.")

      def _select_image_features(self, image_features: torch.Tensor, *,
                                 strategy: str) -> torch.Tensor:
          """Apply ``vision_feature_select_strategy`` to encoder outputs.

          ``"default"`` drops the first token along dim 1 (presumably the
          CLS token), and ``"full"`` keeps every token.

          Raises:
              ValueError: for any other strategy.
          """
          # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
          if strategy == "default":
              return image_features[:, 1:]
          elif strategy == "full":
              return image_features

          raise ValueError(f"Unexpected select feature strategy: {strategy}")

      def _image_pixels_to_features(
          self,
          vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
          pixel_values: torch.Tensor,
      ) -> torch.Tensor:
          """Encode ``pixel_values`` and apply the feature-select strategy."""
          # NOTE: we skip the step to select the vision feature layer since
          # this is already done inside the vision tower
          image_features = vision_tower(pixel_values)

          return self._select_image_features(
              image_features,
              strategy=self.config.vision_feature_select_strategy,
          )

      def _process_image_pixels(self,
                                inputs: LlavaImagePixelInputs) -> torch.Tensor:
          """Run the vision tower on pixel inputs (before projection)."""
          assert self.vision_tower is not None

          pixel_values = inputs["data"]

          return self._image_pixels_to_features(self.vision_tower, pixel_values)

      def _process_image_input(self,
                               image_input: LlavaImageInputs) -> torch.Tensor:
          """Turn parsed image input into language-model-space embeddings.

          Precomputed embeddings are returned unchanged. Pixel inputs go
          through the vision tower and then the MLP projector.
          """
          if image_input["type"] == "image_embeds":
              return image_input["data"]

          assert self.vision_tower is not None
          image_features = self._process_image_pixels(image_input)
          return self.multi_modal_projector(image_features)

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
          intermediate_tensors: Optional[IntermediateTensors] = None,
          **kwargs: object,
      ) -> Union[torch.Tensor, IntermediateTensors]:
          """Run forward pass for LLaVA-1.5.

          One key thing to understand is the `input_ids` already accounts for the
          positions of the to-be-inserted image embeddings.

          Concretely, consider a text prompt:
          `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

          Tokenizer outputs:
          `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
          278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

          To reserve space in the KV cache, placeholder tokens have to be
          inserted before the ids reach the model. The input processor
          therefore expands the image token (`32000`) in place into one
          placeholder per image feature, resulting in:
          `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
          29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
          29901]`.

          In this example, 575 tokens are inserted so that, including the
          original image token, there are 576 (24 * 24) image tokens. That is
          the number of image tokens the visual encoder outputs and the
          language model consumes. This way, the `positions` and
          `attn_metadata` are consistent with the `input_ids`.

          Args:
              input_ids: Flattened (concatenated) input_ids corresponding to a
                  batch.
              positions: Position ids, passed through to the language model.
              kv_caches: KV cache tensors, passed through to the language
                  model.
              attn_metadata: Attention metadata, passed through to the
                  language model.
              intermediate_tensors: Passed through to the language model
                  backbone.
              **kwargs: May carry `pixel_values` (the pixels in each input
                  image) or `image_embeds`. See :class:`LlavaImageInputs`.

          Returns:
              Whatever the language model backbone returns: the hidden states
              (presumably `IntermediateTensors` on a non-final pipeline stage).
              Logits and sampling are handled by `compute_logits` and `sample`.
          """
          image_input = self._parse_and_validate_image_input(**kwargs)

          if image_input is not None:
              vision_embeddings = self._process_image_input(image_input)
              inputs_embeds = self.language_model.model.get_input_embeddings(
                  input_ids)

              inputs_embeds = merge_multimodal_embeddings(
                  input_ids, inputs_embeds, vision_embeddings,
                  self.config.image_token_index)

              # The merged embeddings now carry the text tokens as well, so
              # the backbone must consume them instead of the token ids.
              input_ids = None
          else:
              inputs_embeds = None

          # Forward `intermediate_tensors` rather than a hard-coded None, so
          # the caller's argument is not silently dropped.
          hidden_states = self.language_model.model(input_ids,
                                                    positions,
                                                    kv_caches,
                                                    attn_metadata,
                                                    intermediate_tensors,
                                                    inputs_embeds=inputs_embeds)

          return hidden_states

      def compute_logits(
          self,
          hidden_states: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[torch.Tensor]:
          """Delegate logit computation to the language model."""
          return self.language_model.compute_logits(hidden_states,
                                                    sampling_metadata)

      def sample(
          self,
          logits: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[SamplerOutput]:
          """Delegate token sampling to the language model."""
          return self.language_model.sample(logits, sampling_metadata)

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          """Load checkpoint weights into the three sub-models.

          The weights are grouped by `group_weights_with_prefix`, presumably
          on their leading name component. The `vision_tower` and
          `language_model` groups go to those modules' own loaders, and the
          `multi_modal_projector` group is loaded parameter by parameter.
          """
          # prepare weight iterators for components
          weights_group = group_weights_with_prefix(weights)

          # load vision encoder
          self.vision_tower.load_weights(weights_group["vision_tower"])

          # load mlp projector
          mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
          for name, loaded_weight in weights_group["multi_modal_projector"]:
              param = mlp_params_dict[name]
              # Parameters may carry a custom loader; otherwise fall back to a
              # plain copy via default_weight_loader.
              weight_loader = getattr(param, "weight_loader",
                                      default_weight_loader)
              weight_loader(param, loaded_weight)

          # load llm backbone
          self.language_model.load_weights(weights_group["language_model"])