1
0

llava.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. import itertools
  2. from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
  3. import torch
  4. import torch.nn as nn
  5. from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
  6. from aphrodite.attention import AttentionMetadata
  7. from aphrodite.common.config import CacheConfig, MultiModalConfig
  8. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  9. from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
  10. from aphrodite.modeling.layers.activation import get_act_fn
  11. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  12. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  13. from aphrodite.multimodal import MULTIMODAL_REGISTRY
  14. from aphrodite.quantization.base_config import QuantizationConfig
  15. from .clip import (CLIPVisionModel, dummy_image_for_clip,
  16. dummy_seq_data_for_clip, get_max_clip_image_tokens,
  17. input_processor_for_clip)
  18. from .interfaces import SupportsVision
  19. from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
  20. dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
  21. input_processor_for_siglip)
  22. from .utils import (filter_weights, init_aphrodite_registered_model,
  23. merge_vision_embeddings)
  24. # TODO: Run benchmark and decide if TP.
  25. class LlavaMultiModalProjector(nn.Module):
  26. def __init__(self, vision_hidden_size: int, text_hidden_size: int,
  27. projector_hidden_act: str):
  28. super().__init__()
  29. self.linear_1 = nn.Linear(vision_hidden_size,
  30. text_hidden_size,
  31. bias=True)
  32. self.act = get_act_fn(projector_hidden_act)
  33. self.linear_2 = nn.Linear(text_hidden_size,
  34. text_hidden_size,
  35. bias=True)
  36. def forward(self, image_features: torch.Tensor) -> torch.Tensor:
  37. hidden_states = self.linear_1(image_features)
  38. hidden_states = self.act(hidden_states)
  39. hidden_states = self.linear_2(hidden_states)
  40. return hidden_states
  41. class LlavaImagePixelInputs(TypedDict):
  42. type: Literal["pixel_values"]
  43. data: torch.Tensor
  44. """Shape: `(batch_size, num_channels, height, width)`"""
  45. LlavaImageInputs = LlavaImagePixelInputs
  46. def get_max_llava_image_tokens(ctx: InputContext):
  47. hf_config = ctx.get_hf_config(LlavaConfig)
  48. vision_config = hf_config.vision_config
  49. if isinstance(vision_config, CLIPVisionConfig):
  50. num_image_tokens = get_max_clip_image_tokens(vision_config)
  51. elif isinstance(vision_config, SiglipVisionConfig):
  52. num_image_tokens = get_max_siglip_image_tokens(vision_config)
  53. else:
  54. msg = f"Unsupported vision config: {type(vision_config)}"
  55. raise NotImplementedError(msg)
  56. strategy = hf_config.vision_feature_select_strategy
  57. if strategy == "default":
  58. return num_image_tokens - 1
  59. elif strategy == "full":
  60. return num_image_tokens
  61. else:
  62. raise ValueError(f"Unexpected select feature strategy: {strategy}")
  63. def dummy_data_for_llava(ctx: InputContext, seq_len: int):
  64. hf_config = ctx.get_hf_config(LlavaConfig)
  65. vision_config = hf_config.vision_config
  66. image_feature_size = get_max_llava_image_tokens(ctx)
  67. if isinstance(vision_config, CLIPVisionConfig):
  68. seq_data = dummy_seq_data_for_clip(
  69. vision_config,
  70. seq_len,
  71. image_token_id=hf_config.image_token_index,
  72. image_feature_size_override=image_feature_size,
  73. )
  74. mm_data = dummy_image_for_clip(vision_config)
  75. return seq_data, mm_data
  76. elif isinstance(vision_config, SiglipVisionConfig):
  77. seq_data = dummy_seq_data_for_siglip(
  78. vision_config,
  79. seq_len,
  80. image_token_id=hf_config.image_token_index,
  81. image_feature_size_override=image_feature_size,
  82. )
  83. mm_data = dummy_image_for_siglip(vision_config)
  84. return seq_data, mm_data
  85. msg = f"Unsupported vision config: {type(vision_config)}"
  86. raise NotImplementedError(msg)
  87. def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
  88. multi_modal_data = llm_inputs.get("multi_modal_data")
  89. if multi_modal_data is None or "image" not in multi_modal_data:
  90. return llm_inputs
  91. model_config = ctx.model_config
  92. hf_config = ctx.get_hf_config(LlavaConfig)
  93. vision_config = hf_config.vision_config
  94. image_feature_size = get_max_llava_image_tokens(ctx)
  95. if isinstance(vision_config, CLIPVisionConfig):
  96. return input_processor_for_clip(
  97. model_config,
  98. vision_config,
  99. llm_inputs,
  100. image_token_id=hf_config.image_token_index,
  101. image_feature_size_override=image_feature_size,
  102. )
  103. elif isinstance(vision_config, SiglipVisionConfig):
  104. return input_processor_for_siglip(
  105. model_config,
  106. vision_config,
  107. llm_inputs,
  108. image_token_id=hf_config.image_token_index,
  109. image_feature_size_override=image_feature_size,
  110. )
  111. msg = f"Unsupported vision config: {type(vision_config)}"
  112. raise NotImplementedError(msg)
  113. def _init_vision_tower(hf_config: LlavaConfig):
  114. vision_config = hf_config.vision_config
  115. # Initialize the vision tower only up to the required feature layer
  116. vision_feature_layer = hf_config.vision_feature_layer
  117. if vision_feature_layer < 0:
  118. num_hidden_layers = hf_config.vision_config.num_hidden_layers \
  119. + vision_feature_layer + 1
  120. else:
  121. num_hidden_layers = vision_feature_layer + 1
  122. if isinstance(vision_config, CLIPVisionConfig):
  123. return CLIPVisionModel(
  124. vision_config,
  125. num_hidden_layers_override=num_hidden_layers,
  126. )
  127. elif isinstance(vision_config, SiglipVisionConfig):
  128. return SiglipVisionModel(
  129. vision_config,
  130. num_hidden_layers_override=num_hidden_layers,
  131. )
  132. msg = f"Unsupported vision config: {type(vision_config)}"
  133. raise NotImplementedError(msg)
  134. @MULTIMODAL_REGISTRY.register_image_input_mapper()
  135. @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
  136. @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
  137. @INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
  138. class LlavaForConditionalGeneration(nn.Module, SupportsVision):
  139. def __init__(self,
  140. config: LlavaConfig,
  141. multimodal_config: MultiModalConfig,
  142. cache_config: Optional[CacheConfig] = None,
  143. quant_config: Optional[QuantizationConfig] = None) -> None:
  144. super().__init__()
  145. self.config = config
  146. self.multimodal_config = multimodal_config
  147. # TODO: Optionally initializes this for supporting embeddings.
  148. self.vision_tower = _init_vision_tower(config)
  149. self.multi_modal_projector = LlavaMultiModalProjector(
  150. vision_hidden_size=config.vision_config.hidden_size,
  151. text_hidden_size=config.text_config.hidden_size,
  152. projector_hidden_act=config.projector_hidden_act)
  153. self.language_model = init_aphrodite_registered_model(
  154. config.text_config, cache_config, quant_config)
  155. def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
  156. h = w = self.config.vision_config.image_size
  157. expected_dims = (3, h, w)
  158. actual_dims = tuple(data.shape[1:])
  159. if actual_dims != expected_dims:
  160. expected_expr = ("batch_size", *map(str, expected_dims))
  161. raise ValueError(
  162. f"The expected shape of pixel values is {expected_expr}. "
  163. f"You supplied {tuple(data.shape)}.")
  164. return data
  165. def _parse_and_validate_image_input(
  166. self, **kwargs: object) -> Optional[LlavaImageInputs]:
  167. pixel_values = kwargs.pop("pixel_values", None)
  168. if pixel_values is None:
  169. return None
  170. if not isinstance(pixel_values, torch.Tensor):
  171. raise ValueError("Incorrect type of pixel values. "
  172. f"Got type: {type(pixel_values)}")
  173. return LlavaImagePixelInputs(
  174. type="pixel_values",
  175. data=self._validate_pixel_values(pixel_values),
  176. )
  177. def _select_image_features(self, image_features: torch.Tensor, *,
  178. strategy: str) -> torch.Tensor:
  179. # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
  180. if strategy == "default":
  181. return image_features[:, 1:]
  182. elif strategy == "full":
  183. return image_features
  184. raise ValueError(f"Unexpected select feature strategy: {strategy}")
  185. def _image_pixels_to_features(
  186. self,
  187. vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
  188. pixel_values: torch.Tensor,
  189. ) -> torch.Tensor:
  190. # NOTE: we skip the step to select the vision feature layer since
  191. # this is already done inside the vision tower
  192. image_features = vision_tower(pixel_values)
  193. return self._select_image_features(
  194. image_features,
  195. strategy=self.config.vision_feature_select_strategy,
  196. )
  197. def _process_image_pixels(self,
  198. inputs: LlavaImagePixelInputs) -> torch.Tensor:
  199. assert self.vision_tower is not None
  200. pixel_values = inputs["data"]
  201. return self._image_pixels_to_features(self.vision_tower, pixel_values)
  202. def _process_image_input(self,
  203. image_input: LlavaImageInputs) -> torch.Tensor:
  204. assert self.vision_tower is not None
  205. image_features = self._process_image_pixels(image_input)
  206. return self.multi_modal_projector(image_features)
  207. def forward(
  208. self,
  209. input_ids: torch.Tensor,
  210. positions: torch.Tensor,
  211. kv_caches: List[torch.Tensor],
  212. attn_metadata: AttentionMetadata,
  213. intermediate_tensors: Optional[IntermediateTensors] = None,
  214. **kwargs: object,
  215. ) -> SamplerOutput:
  216. """Run forward pass for LLaVA-1.5.
  217. One key thing to understand is the `input_ids` already accounts for the
  218. positions of the to-be-inserted image embeddings.
  219. Concretely, consider a text prompt:
  220. `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.
  221. Tokenizer outputs:
  222. `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
  223. 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.
  224. To reserve space in KV cache, we have to insert placeholder tokens
  225. before they are inputted to the model, so the input processor prepends
  226. additional image tokens (denoted as `32000`), resulting in:
  227. `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
  228. 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
  229. 29901]`.
  230. We insert 575 tokens so that including the original image token in the
  231. input, there are a total of 576 (24 * 24) image tokens, which
  232. corresponds to the number of image tokens inputted to the language
  233. model, i.e. the number of image tokens outputted by the visual encoder.
  234. This way, the `positions` and `attn_metadata` are consistent
  235. with the `input_ids`.
  236. Args:
  237. input_ids: Flattened (concatenated) input_ids corresponding to a
  238. batch.
  239. pixel_values: The pixels in each input image.
  240. See also:
  241. :class:`LlavaImageInputs`
  242. """
  243. image_input = self._parse_and_validate_image_input(**kwargs)
  244. if image_input is not None:
  245. vision_embeddings = self._process_image_input(image_input)
  246. inputs_embeds = self.language_model.model.get_input_embeddings(
  247. input_ids)
  248. inputs_embeds = merge_vision_embeddings(
  249. input_ids, inputs_embeds, vision_embeddings,
  250. self.config.image_token_index)
  251. input_ids = None
  252. else:
  253. inputs_embeds = None
  254. hidden_states = self.language_model.model(input_ids,
  255. positions,
  256. kv_caches,
  257. attn_metadata,
  258. None,
  259. inputs_embeds=inputs_embeds)
  260. return hidden_states
  261. def compute_logits(self, hidden_states: torch.Tensor,
  262. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  263. return self.language_model.compute_logits(hidden_states,
  264. sampling_metadata)
  265. def sample(
  266. self,
  267. logits: torch.Tensor,
  268. sampling_metadata: SamplingMetadata,
  269. ) -> Optional[SamplerOutput]:
  270. return self.language_model.sample(logits, sampling_metadata)
  271. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  272. # prepare weight iterators for components
  273. vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
  274. # load vision encoder
  275. vit_weights = filter_weights(vit_weights, "vision_tower")
  276. self.vision_tower.load_weights(vit_weights)
  277. # load mlp projector
  278. mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
  279. mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
  280. for name, loaded_weight in mlp_weights:
  281. param = mlp_params_dict[name]
  282. weight_loader = getattr(param, "weight_loader",
  283. default_weight_loader)
  284. weight_loader(param, loaded_weight)
  285. # load llm backbone
  286. llm_weights = filter_weights(llm_weights, "language_model")
  287. self.language_model.load_weights(llm_weights)