# llava.py
import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.common.utils import is_list_of
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.quantization.base_config import QuantizationConfig

from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_max_clip_image_tokens,
                   input_processor_for_clip)
from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                     input_processor_for_siglip)
from .utils import (filter_weights, flatten_bn,
                    init_aphrodite_registered_model,
                    merge_multimodal_embeddings)


class LlavaImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


class LlavaImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """


LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


# TODO: Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        super().__init__()

        self.linear_1 = nn.Linear(vision_hidden_size,
                                  text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden_size,
                                  text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
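
# NOTE (illustrative): the projector maps the vision tower's patch features
# from (num_images, image_feature_size, vision_hidden_size) to
# (num_images, image_feature_size, text_hidden_size), typically 1024 -> 4096
# for a LLaVA-1.5-7B checkpoint, so the projected features can be merged
# into the language model's input embeddings in
# LlavaForConditionalGeneration.forward() below.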


def get_max_llava_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_image_tokens = get_max_clip_image_tokens(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_image_tokens = get_max_siglip_image_tokens(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        return num_image_tokens - 1
    elif strategy == "full":
        return num_image_tokens
    else:
        raise ValueError(f"Unexpected select feature strategy: {strategy}")
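
# Worked example (illustrative, assuming the stock LLaVA-1.5 vision tower,
# i.e. CLIP ViT-L/14 at 336x336 resolution): the encoder produces
# (336 / 14) ** 2 = 576 patch tokens plus one CLS token, for 577 in total.
# Under the "default" strategy the CLS token is later dropped, so this
# function reports 576 image tokens, matching the 24 * 24 figure quoted in
# the forward() docstring below.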


def dummy_data_for_llava(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    image_feature_size = get_max_llava_image_tokens(ctx)

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_clip(vision_config, num_images)
        return seq_data, mm_data
    elif isinstance(vision_config, SiglipVisionConfig):
        seq_data = dummy_seq_data_for_siglip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_siglip(vision_config, num_images)
        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        image_feature_size = get_max_llava_image_tokens(ctx)
    elif is_list_of(image_data, Image.Image):
        image_feature_size = [get_max_llava_image_tokens(ctx)
                              ] * len(image_data)
    elif isinstance(image_data, torch.Tensor):
        num_images, image_feature_size, hidden_size = image_data.shape
    elif is_list_of(image_data, torch.Tensor):
        image_feature_size = [item.shape[1] for item in image_data]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return input_processor_for_siglip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def _init_vision_tower(hf_config: LlavaConfig):
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer
    vision_feature_layer = hf_config.vision_feature_layer
    if vision_feature_layer < 0:
        num_hidden_layers = hf_config.vision_config.num_hidden_layers \
            + vision_feature_layer + 1
    else:
        num_hidden_layers = vision_feature_layer + 1

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
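
# Worked example (illustrative): LLaVA-1.5 checkpoints typically set
# vision_feature_layer = -2 with a 24-layer CLIP tower, giving
# num_hidden_layers = 24 + (-2) + 1 = 23, so the final encoder layer is
# never instantiated because its output would never be consumed.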


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: LlavaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initialize this for supporting embeddings.
        self.vision_tower = _init_vision_tower(config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_aphrodite_registered_model(
            config.text_config, cache_config, quant_config)

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data
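
    # Illustrative check (assuming the standard 336x336 LLaVA-1.5 tower):
    # pixel values must arrive as (batch_size * num_images, 3, 336, 336);
    # any other trailing dimensions trigger the ValueError above.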

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")
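
    # NOTE: under the "default" strategy the first position (which, for a
    # CLIP tower, holds the CLS token) is sliced off; this mirrors the
    # `num_image_tokens - 1` adjustment in get_max_llava_image_tokens().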

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        return self.multi_modal_projector(image_features)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run a forward pass for LLaVA-1.5.

        One key thing to understand is that the `input_ids` already account
        for the positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in the KV cache, we have to insert placeholder tokens
        before they are passed to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that, including the original image token in
        the input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens passed to the language
        model, i.e. the number of image tokens output by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image.

        See also:
            :class:`LlavaImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states
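
    # Summary of the multimodal path above: the projected vision embeddings
    # overwrite the placeholder image-token positions (token id
    # `config.image_token_index`, 32000 in the docstring example) within the
    # text embeddings, and the language model then runs purely on
    # `inputs_embeds`; text-only requests skip the vision tower entirely and
    # pass `input_ids` through unchanged.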

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # prepare weight iterators for components
        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)

        # load vision encoder
        vit_weights = filter_weights(vit_weights, "vision_tower")
        self.vision_tower.load_weights(vit_weights)

        # load mlp projector
        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in mlp_weights:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)