llava.py
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict

import torch
import torch.nn as nn
from transformers import CLIPVisionConfig, LlavaConfig

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.clip import CLIPVisionModel
from aphrodite.modeling.models.llama import LlamaModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.quantization.base_config import QuantizationConfig

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                   get_max_clip_image_tokens, input_processor_for_clip)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings

_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}


# TODO: Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        super().__init__()

        self.linear_1 = nn.Linear(vision_hidden_size,
                                  text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden_size,
                                  text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
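

# --- Illustrative sketch (not part of the original module) ---
# A minimal, torch-only check of the projector's shape contract, assuming
# CLIP-L/336-style sizes (vision hidden size 1024, text hidden size 4096) and
# 576 patch tokens per image; these numbers are hypothetical and chosen
# purely for illustration:
#
#     projector = LlavaMultiModalProjector(vision_hidden_size=1024,
#                                          text_hidden_size=4096,
#                                          projector_hidden_act="gelu")
#     image_features = torch.randn(2, 576, 1024)  # (num_images, patches, dim)
#     projected = projector(image_features)
#     assert projected.shape == (2, 576, 4096)  # now in the LM's hidden size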


class LlavaImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size, num_channels, height, width)`"""


LlavaImageInputs = LlavaImagePixelInputs


def get_max_llava_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        return get_max_clip_image_tokens(vision_config)

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def dummy_data_for_llava(ctx: InputContext, seq_len: int):
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            image_token_id=hf_config.image_token_index,
        )

        mm_data = dummy_image_for_clip(vision_config)
        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
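

# --- Illustrative sketch (not part of the original module) ---
# Conceptually, `input_processor_for_clip` expands the single `<image>`
# placeholder in the prompt into one placeholder token per visual patch, so
# that the KV cache reserves enough slots (see the docstring of
# `LlavaForConditionalGeneration.forward` below). A rough pure-Python
# equivalent of that expansion, using the 576-patch count and placeholder id
# 32000 from the example in that docstring:
#
#     def expand_image_tokens(token_ids, image_token_id=32000, num_patches=576):
#         expanded = []
#         for tok in token_ids:
#             expanded.extend([image_token_id] * num_patches
#                             if tok == image_token_id else [tok])
#         return expanded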


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):

    def __init__(self,
                 config: LlavaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # Initialize the vision tower only up to the required feature layer
        vision_feature_layer = config.vision_feature_layer
        if vision_feature_layer < 0:
            num_hidden_layers = config.vision_config.num_hidden_layers \
                + vision_feature_layer + 1
        else:
            num_hidden_layers = vision_feature_layer + 1

        # TODO: Optionally initialize this to support embeddings.
        self.vision_tower = CLIPVisionModel(
            config.vision_config, num_hidden_layers_override=num_hidden_layers)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.quant_config = quant_config
        self.language_model = LlamaModel(config.text_config, cache_config,
                                         quant_config)
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size,
            quant_config=quant_config)
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.text_config.vocab_size,
                                                logit_scale)
        self.sampler = Sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return LlavaImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        return self.multi_modal_projector(image_features)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run the forward pass for LLaVA-1.5.

        One key thing to understand is that `input_ids` already accounts for
        the positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in the KV cache, we have to insert placeholder tokens
        before they are input to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that, including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens fed to the language model,
        i.e. the number of image tokens output by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image.

        See also:
            :class:`LlavaImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            None,
                                            inputs_embeds=inputs_embeds)

        return hidden_states
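
    # --- Illustrative sketch (not part of the original module) ---
    # `merge_vision_embeddings` (imported from .utils) scatters the projected
    # image features into the text embeddings at the placeholder positions.
    # A simplified torch-only sketch of that idea (an approximation of the
    # utility's behavior, not a copy of it):
    #
    #     mask = input_ids == image_token_index            # placeholder slots
    #     inputs_embeds[mask] = vision_embeddings.reshape(
    #         -1, inputs_embeds.shape[-1])                 # one row per slot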

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Only doing this for the language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # post_layernorm is not needed in CLIPVisionModel
            if "vision_model.post_layernorm" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for the language model and
                    # not the vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    use_default_weight_loading = True
            if use_default_weight_loading and name in params_dict:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
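

# --- Illustrative sketch (not part of the original module) ---
# How this model might be exercised end to end, assuming Aphrodite mirrors
# vLLM's offline multi-modal `LLM.generate` interface; the import path, model
# name, and sampling parameters below are assumptions, not taken from this
# file, while the prompt format comes from the `forward` docstring above:
#
#     from PIL import Image
#     from aphrodite import LLM, SamplingParams
#
#     llm = LLM(model="llava-hf/llava-1.5-7b-hf")
#     image = Image.open("example.jpg")
#     prompt = "USER: <image>\nWhat's the content of the image?\nASSISTANT:"
#     outputs = llm.generate(
#         {"prompt": prompt, "multi_modal_data": {"image": image}},
#         SamplingParams(max_tokens=64),
#     )
#     print(outputs[0].outputs[0].text)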