# llava.py -- LLaVA vision-language model (CLIP vision tower + LLaMA).
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict

import torch
import torch.nn as nn
from transformers import CLIPVisionConfig, LlavaConfig

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, VisionLanguageConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.clip import CLIPVisionModel
from aphrodite.modeling.models.llama import LlamaModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.quantization.base_config import QuantizationConfig

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                   input_processor_for_clip)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
  23. _KEYS_TO_MODIFY_MAPPING = {
  24. "language_model.lm_head": "lm_head",
  25. "language_model.model": "language_model",
  26. }
  27. # TODO: Run benchmark and decide if TP.
  28. class LlavaMultiModalProjector(nn.Module):
  29. def __init__(self, vision_hidden_size: int, text_hidden_size: int,
  30. projector_hidden_act: str):
  31. super().__init__()
  32. self.linear_1 = nn.Linear(vision_hidden_size,
  33. text_hidden_size,
  34. bias=True)
  35. self.act = get_act_fn(projector_hidden_act)
  36. self.linear_2 = nn.Linear(text_hidden_size,
  37. text_hidden_size,
  38. bias=True)
  39. def forward(self, image_features: torch.Tensor) -> torch.Tensor:
  40. hidden_states = self.linear_1(image_features)
  41. hidden_states = self.act(hidden_states)
  42. hidden_states = self.linear_2(hidden_states)
  43. return hidden_states
  44. class LlavaImagePixelInputs(TypedDict):
  45. type: Literal["pixel_values"]
  46. data: torch.Tensor
  47. """Shape: `(batch_size, num_channels, height, width)`"""
  48. LlavaImageInputs = LlavaImagePixelInputs
  49. def dummy_data_for_llava(ctx: InputContext, seq_len: int):
  50. hf_config = ctx.get_hf_config(LlavaConfig)
  51. vision_config = hf_config.vision_config
  52. if isinstance(vision_config, CLIPVisionConfig):
  53. seq_data = dummy_seq_data_for_clip(
  54. vision_config,
  55. seq_len,
  56. image_token_id=hf_config.image_token_index,
  57. )
  58. mm_data = dummy_image_for_clip(vision_config)
  59. return seq_data, mm_data
  60. msg = f"Unsupported vision config: {type(vision_config)}"
  61. raise NotImplementedError(msg)
  62. def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
  63. multi_modal_data = llm_inputs.get("multi_modal_data")
  64. if multi_modal_data is None or "image" not in multi_modal_data:
  65. return llm_inputs
  66. model_config = ctx.model_config
  67. hf_config = ctx.get_hf_config(LlavaConfig)
  68. vision_config = hf_config.vision_config
  69. if isinstance(vision_config, CLIPVisionConfig):
  70. return input_processor_for_clip(
  71. model_config,
  72. vision_config,
  73. llm_inputs,
  74. image_token_id=hf_config.image_token_index,
  75. )
  76. msg = f"Unsupported vision config: {type(vision_config)}"
  77. raise NotImplementedError(msg)
  78. @MULTIMODAL_REGISTRY.register_image_input_mapper()
  79. @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
  80. @INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
  81. class LlavaForConditionalGeneration(nn.Module, SupportsVision):
  82. def __init__(self,
  83. config: LlavaConfig,
  84. vlm_config: VisionLanguageConfig,
  85. cache_config: Optional[CacheConfig] = None,
  86. quant_config: Optional[QuantizationConfig] = None) -> None:
  87. super().__init__()
  88. self.config = config
  89. self.vlm_config = vlm_config
  90. # TODO: Optionally initializes this for supporting embeddings.
  91. self.vision_tower = CLIPVisionModel(config.vision_config)
  92. self.multi_modal_projector = LlavaMultiModalProjector(
  93. vision_hidden_size=config.vision_config.hidden_size,
  94. text_hidden_size=config.text_config.hidden_size,
  95. projector_hidden_act=config.projector_hidden_act)
  96. self.quant_config = quant_config
  97. self.language_model = LlamaModel(config.text_config, cache_config,
  98. quant_config)
  99. self.unpadded_vocab_size = config.text_config.vocab_size
  100. self.lm_head = ParallelLMHead(
  101. self.unpadded_vocab_size,
  102. config.text_config.hidden_size,
  103. org_num_embeddings=self.language_model.org_vocab_size,
  104. quant_config=quant_config)
  105. logit_scale = getattr(config, "logit_scale", 1.0)
  106. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  107. config.vocab_size, logit_scale)
  108. self.sampler = Sampler()
  109. def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
  110. if list(data.shape[1:]) != list(self.vlm_config.image_input_shape[1:]):
  111. raise ValueError(
  112. f"The expected image tensor shape is batch dimension plus "
  113. f"{self.vlm_config.image_input_shape[1:]}. "
  114. f"You supplied {data.shape}. "
  115. f"If you are using vLLM's entrypoint, make sure your "
  116. f"supplied image input is consistent with "
  117. f"image_input_shape in engine args.")
  118. return data
  119. def _parse_and_validate_image_input(
  120. self, **kwargs: object) -> Optional[LlavaImageInputs]:
  121. pixel_values = kwargs.pop("pixel_values", None)
  122. if pixel_values is None:
  123. return None
  124. if not isinstance(pixel_values, torch.Tensor):
  125. raise ValueError("Incorrect type of pixel values. "
  126. f"Got type: {type(pixel_values)}")
  127. return LlavaImagePixelInputs(
  128. type="pixel_values",
  129. data=self._validate_image_data(pixel_values),
  130. )
  131. def _select_image_features(self, image_features: torch.Tensor, *,
  132. strategy: str) -> torch.Tensor:
  133. # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
  134. if strategy == "default":
  135. return image_features[:, 1:]
  136. elif strategy == "full":
  137. return image_features
  138. raise ValueError(f"Unexpected select feature strategy: {strategy}")
  139. def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
  140. pixel_values: torch.Tensor) -> torch.Tensor:
  141. # NOTE: we skip the step to select the vision feature layer since
  142. # this is already done inside the vision tower
  143. image_features = vision_tower(pixel_values,
  144. self.config.vision_feature_layer)
  145. return self._select_image_features(
  146. image_features,
  147. strategy=self.config.vision_feature_select_strategy,
  148. )
  149. def _process_image_pixels(self,
  150. inputs: LlavaImagePixelInputs) -> torch.Tensor:
  151. assert self.vision_tower is not None
  152. pixel_values = inputs["data"]
  153. return self._image_pixels_to_features(self.vision_tower, pixel_values)
  154. def _process_image_input(self,
  155. image_input: LlavaImageInputs) -> torch.Tensor:
  156. assert self.vision_tower is not None
  157. image_features = self._process_image_pixels(image_input)
  158. return self.multi_modal_projector(image_features)
  159. def forward(
  160. self,
  161. input_ids: torch.Tensor,
  162. positions: torch.Tensor,
  163. kv_caches: List[torch.Tensor],
  164. attn_metadata: AttentionMetadata,
  165. intermediate_tensors: Optional[IntermediateTensors] = None,
  166. **kwargs: object,
  167. ) -> SamplerOutput:
  168. """Run forward pass for LLaVA-1.5.
  169. One key thing to understand is the `input_ids` already accounts for the
  170. positions of the to-be-inserted image embeddings.
  171. Concretely, consider a text prompt:
  172. "<image>\nUSER: What's the content of the image?\nASSISTANT:".
  173. Tokenizer outputs:
  174. [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
  175. 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
  176. The to-be-inserted image has a size of 576 (24 * 24) along the context
  177. length dimension.
  178. `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
  179. 1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
  180. 9047, 13566, 29901].
  181. There will be 576 `32000` in the `input_ids`.
  182. (32000 is the token id for `<image>`.)
  183. This way, the `positions` and `attn_metadata` are consistent
  184. with the `input_ids`.
  185. Args:
  186. input_ids: Flattened (concatenated) input_ids corresponding to a
  187. batch.
  188. pixel_values: The pixels in each input image.
  189. """
  190. image_input = self._parse_and_validate_image_input(**kwargs)
  191. if image_input is not None:
  192. vision_embeddings = self._process_image_input(image_input)
  193. inputs_embeds = self.language_model.get_input_embeddings(input_ids)
  194. inputs_embeds = merge_vision_embeddings(
  195. input_ids, inputs_embeds, vision_embeddings,
  196. self.vlm_config.image_token_id)
  197. input_ids = None
  198. else:
  199. inputs_embeds = None
  200. hidden_states = self.language_model(input_ids,
  201. positions,
  202. kv_caches,
  203. attn_metadata,
  204. None,
  205. inputs_embeds=inputs_embeds)
  206. return hidden_states
  207. def compute_logits(self, hidden_states: torch.Tensor,
  208. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  209. logits = self.logits_processor(self.lm_head, hidden_states,
  210. sampling_metadata)
  211. return logits
  212. def sample(
  213. self,
  214. logits: torch.Tensor,
  215. sampling_metadata: SamplingMetadata,
  216. ) -> Optional[SamplerOutput]:
  217. next_tokens = self.sampler(logits, sampling_metadata)
  218. return next_tokens
  219. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  220. # only doing this for language model part for now.
  221. stacked_params_mapping = [
  222. # (param_name, shard_name, shard_id)
  223. ("qkv_proj", "q_proj", "q"),
  224. ("qkv_proj", "k_proj", "k"),
  225. ("qkv_proj", "v_proj", "v"),
  226. ("gate_up_proj", "gate_proj", 0),
  227. ("gate_up_proj", "up_proj", 1),
  228. ]
  229. params_dict = dict(self.named_parameters())
  230. for name, loaded_weight in weights:
  231. if "rotary_emb.inv_freq" in name:
  232. continue
  233. # post_layernorm is not needed in CLIPVisionModel
  234. if "vision_model.post_layernorm" in name:
  235. continue
  236. for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
  237. if key_to_modify in name:
  238. name = name.replace(key_to_modify, new_key)
  239. use_default_weight_loading = False
  240. if "vision" in name:
  241. if self.vision_tower is not None:
  242. # We only do sharding for language model and
  243. # not vision model for now.
  244. use_default_weight_loading = True
  245. else:
  246. for (param_name, weight_name,
  247. shard_id) in stacked_params_mapping:
  248. if weight_name not in name:
  249. continue
  250. param = params_dict[name.replace(weight_name, param_name)]
  251. weight_loader = param.weight_loader
  252. weight_loader(param, loaded_weight, shard_id)
  253. break
  254. else:
  255. use_default_weight_loading = True
  256. if use_default_weight_loading:
  257. param = params_dict[name]
  258. weight_loader = getattr(param, "weight_loader",
  259. default_weight_loader)
  260. weight_loader(param, loaded_weight)