# llava.py

from typing import Iterable, List, Literal, Optional, Tuple, TypedDict

import torch
import torch.nn as nn
from transformers import CLIPVisionConfig, LlavaConfig

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, VisionLanguageConfig
from aphrodite.common.sequence import SamplerOutput
from aphrodite.inputs import INPUT_REGISTRY, InputContext
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.clip import CLIPVisionModel
from aphrodite.modeling.models.llama import LlamaModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.quantization.base_config import QuantizationConfig

from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsVision

_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}


# TODO: Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        super().__init__()

        self.linear_1 = nn.Linear(vision_hidden_size,
                                  text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden_size,
                                  text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
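
# Shape sketch for LlavaMultiModalProjector above. Illustrative only: the
# hidden sizes and patch count are assumptions typical of LLaVA-1.5
# (CLIP ViT-L/14-336 plus a 7B text model), not values read from a config.
#
#     proj = LlavaMultiModalProjector(vision_hidden_size=1024,
#                                     text_hidden_size=4096,
#                                     projector_hidden_act="gelu")
#     feats = torch.randn(1, 576, 1024)  # (num_images, num_patches, vision_hidden)
#     out = proj(feats)                  # -> (1, 576, 4096), i.e. text hidden size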


def merge_vision_embeddings(input_ids: torch.Tensor,
                            inputs_embeds: torch.Tensor,
                            vision_embeddings: torch.Tensor,
                            image_token_id: int) -> torch.Tensor:
    """Merge vision_embeddings into inputs_embeds in place."""
    mask = (input_ids == image_token_id)

    image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1]
    if mask.sum() != image_feature_size:
        raise ValueError(f"image_feature_size should be {image_feature_size}, "
                         f"but found: {mask.sum()}")

    inputs_embeds[mask] = vision_embeddings.view(image_feature_size,
                                                 vision_embeddings.shape[-1])

    return inputs_embeds
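
# Minimal sketch of the merge above, assuming a 4-token prompt where token id
# 32000 is the image placeholder. The tensors are made up purely to show the
# shapes involved:
#
#     input_ids = torch.tensor([1, 32000, 32000, 13])
#     inputs_embeds = torch.zeros(4, 8)          # (num_tokens, hidden)
#     vision_embeddings = torch.ones(1, 2, 8)    # (num_images, feat, hidden)
#     merge_vision_embeddings(input_ids, inputs_embeds, vision_embeddings,
#                             image_token_id=32000)
#     # Rows 1 and 2 of inputs_embeds now hold the image features.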


class LlavaImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]

    data: torch.Tensor
    """Shape: (batch_size, num_channels, height, width)"""


LlavaImageInputs = LlavaImagePixelInputs


def dummy_data_for_llava(ctx: InputContext, seq_len: int):
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            image_token_id=hf_config.image_token_index,
        )

        mm_data = dummy_image_for_clip(vision_config)
        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):

    def __init__(self,
                 config: LlavaConfig,
                 vlm_config: VisionLanguageConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.vlm_config = vlm_config

        # TODO: Optionally initialize this for supporting embeddings.
        self.vision_tower = CLIPVisionModel(config.vision_config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.quant_config = quant_config
        self.language_model = LlamaModel(config.text_config, cache_config,
                                         quant_config)
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size)
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
        if list(data.shape[1:]) != list(self.vlm_config.image_input_shape[1:]):
            raise ValueError(
                f"The expected image tensor shape is batch dimension plus "
                f"{self.vlm_config.image_input_shape[1:]}. "
                f"You supplied {data.shape}. "
                f"If you are using Aphrodite's entrypoint, make sure your "
                f"supplied image input is consistent with "
                f"image_input_shape in engine args.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return LlavaImagePixelInputs(
            type="pixel_values",
            data=self._validate_image_data(pixel_values),
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values,
                                      self.config.vision_feature_layer)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        return self.multi_modal_projector(image_features)
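
    # Descriptive summary of the image path implemented by the three helpers
    # above (exact sizes depend on the vision and text configs):
    #   pixel_values (batch, channels, height, width)
    #     -> CLIPVisionModel          : patch features (batch, 1 + num_patches, vision_hidden)
    #     -> _select_image_features   : "default" drops the leading CLS token
    #     -> LlavaMultiModalProjector : (batch, num_patches, text_hidden), ready to be
    #        merged into the token embeddings by merge_vision_embeddings().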

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is that `input_ids` already accounts for
        the positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        "<image>\nUSER: What's the content of the image?\nASSISTANT:".

        Tokenizer outputs:
        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].

        The to-be-inserted image has a size of 576 (24 * 24) along the context
        length dimension.
        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
        9047, 13566, 29901], containing 576 occurrences of `32000`
        (the token id for `<image>`).

        This way, the `positions` and `attn_metadata` are consistent with the
        `input_ids` (see the sketch after this method).

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.vlm_config.image_token_id)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            inputs_embeds=inputs_embeds)

        return hidden_states
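
    # Illustrative sketch of the placeholder expansion described in the
    # forward() docstring. The expansion itself happens during input
    # processing, outside this class; the numbers mirror the docstring:
    #
    #     image_token_id = 32000
    #     prompt_ids = [1, 32000, 29871, 13, ...]       # single <image> token
    #     expanded = ([prompt_ids[0]]
    #                 + [image_token_id] * 576          # 576 = 24 * 24 patches
    #                 + prompt_ids[2:])                 # rest of the prompt
    #     # forward() receives `expanded`, so `positions` and `attn_metadata`
    #     # already line up with the inserted image features.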

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # post_layernorm is not needed in CLIPVisionModel
            if "vision_model.post_layernorm" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    use_default_weight_loading = True

            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
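

# Illustrative example of the key handling in load_weights() above. The
# checkpoint names below are assumptions about a typical HF LLaVA checkpoint,
# shown only to make the remapping and sharding concrete:
#
#     "language_model.model.layers.0.self_attn.q_proj.weight"
#         -- _KEYS_TO_MODIFY_MAPPING -->  "language_model.layers.0.self_attn.q_proj.weight"
#         -- stacked_params_mapping  -->  loaded into "...self_attn.qkv_proj.weight"
#                                         with shard_id "q"
#     "vision_tower.vision_model.embeddings.patch_embedding.weight"
#         -->  loaded verbatim with default_weight_loader (no sharding).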