# llava.py (13 KB)
  1. from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
  2. import torch
  3. from torch import nn
  4. # TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
  5. # transformers' impl.
  6. from transformers import CLIPVisionModel, LlavaConfig
  7. from aphrodite.attention import AttentionMetadata
  8. from aphrodite.common.config import CacheConfig, VisionLanguageConfig
  9. from aphrodite.common.sequence import SamplerOutput
  10. from aphrodite.modeling.layers.activation import get_act_fn
  11. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  12. from aphrodite.modeling.layers.sampler import Sampler
  13. from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
  14. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  15. from aphrodite.modeling.models.llama import LlamaModel
  16. from aphrodite.modeling.models.vlm_base import VisionLanguageModelBase
  17. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  18. from aphrodite.quantization.base_config import QuantizationConfig
# Substring renames applied to checkpoint weight names in `load_weights`
# before the parameter lookup: every key that occurs in a weight name is
# replaced with its value.
_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}
# TODO(xwjiang): Run a benchmark and decide whether the projector below
# should use tensor parallelism (TP).
  24. class LlavaMultiModalProjector(nn.Module):
  25. def __init__(self, vision_hidden_size: int, text_hidden_size: int,
  26. projector_hidden_act: str):
  27. super().__init__()
  28. self.linear_1 = nn.Linear(vision_hidden_size,
  29. text_hidden_size,
  30. bias=True)
  31. self.act = get_act_fn(projector_hidden_act)
  32. self.linear_2 = nn.Linear(text_hidden_size,
  33. text_hidden_size,
  34. bias=True)
  35. def forward(self, image_features: torch.Tensor) -> torch.Tensor:
  36. hidden_states = self.linear_1(image_features)
  37. hidden_states = self.act(hidden_states)
  38. hidden_states = self.linear_2(hidden_states)
  39. return hidden_states
  40. def _merge_vision_embeddings(input_ids: torch.Tensor,
  41. inputs_embeds: torch.Tensor,
  42. vision_embeddings: torch.Tensor,
  43. image_token_id: int) -> torch.Tensor:
  44. """In place merges in vision_embeddings with inputs_embeds."""
  45. mask = (input_ids == image_token_id)
  46. image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1]
  47. if mask.sum() != image_feature_size:
  48. raise ValueError(f"image_feature_size should be {image_feature_size}, "
  49. f"but found: {mask.sum()}")
  50. inputs_embeds[mask] = vision_embeddings.view(image_feature_size,
  51. vision_embeddings.shape[-1])
  52. return inputs_embeds
  53. class LlavaImagePixelInputs(TypedDict):
  54. type: Literal["pixel_values"]
  55. data: torch.Tensor
  56. """Shape: (batch_size, num_channels, height, width)"""
  57. class LlavaImageFeatureInputs(TypedDict):
  58. type: Literal["image_features"]
  59. data: torch.Tensor
  60. """Shape: (batch_size, image_feature_size, hidden_size)"""
# Either kind of image input; the "type" entry tells the two variants apart.
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
  62. class LlavaForConditionalGeneration(VisionLanguageModelBase):
  63. def __init__(self,
  64. config: LlavaConfig,
  65. vision_language_config: VisionLanguageConfig,
  66. cache_config: Optional[CacheConfig] = None,
  67. quant_config: Optional[QuantizationConfig] = None) -> None:
  68. super().__init__(vision_language_config)
  69. self.config = config
  70. if self.vision_language_config.image_input_type == (
  71. VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
  72. self.vision_tower = CLIPVisionModel(config.vision_config)
  73. else:
  74. self.vision_tower = None
  75. self.multi_modal_projector = LlavaMultiModalProjector(
  76. vision_hidden_size=config.vision_config.hidden_size,
  77. text_hidden_size=config.text_config.hidden_size,
  78. projector_hidden_act=config.projector_hidden_act)
  79. self.quant_config = quant_config
  80. self.language_model = LlamaModel(config.text_config, cache_config,
  81. quant_config)
  82. self.unpadded_vocab_size = config.text_config.vocab_size
  83. self.lm_head = ParallelLMHead(
  84. self.unpadded_vocab_size,
  85. config.text_config.hidden_size,
  86. org_num_embeddings=self.language_model.org_vocab_size)
  87. logit_scale = getattr(config, "logit_scale", 1.0)
  88. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  89. config.vocab_size, logit_scale)
  90. self.sampler = Sampler()
  91. def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
  92. if list(data.shape[1:]) != list(
  93. self.vision_language_config.image_input_shape[1:]):
  94. raise ValueError(
  95. f"The expected image tensor shape is batch dimension plus "
  96. f"{self.vision_language_config.image_input_shape[1:]}. "
  97. f"You supplied {data.shape}. "
  98. f"If you are using vLLM's entrypoint, make sure your "
  99. f"supplied image input is consistent with "
  100. f"image_input_shape in engine args.")
  101. return data
  102. def _parse_and_validate_image_input(
  103. self, data: object) -> Optional[LlavaImageInputs]:
  104. expected_input_type = self.vision_language_config.image_input_type
  105. ImageInputType = VisionLanguageConfig.ImageInputType
  106. if data is None:
  107. return None
  108. if expected_input_type == ImageInputType.PIXEL_VALUES:
  109. if not isinstance(data, torch.Tensor):
  110. raise TypeError("Image pixel vector should be a tensor, "
  111. f"but received type: {type(data)}")
  112. return LlavaImagePixelInputs(
  113. type="pixel_values",
  114. data=self._validate_image_data(data),
  115. )
  116. elif expected_input_type == ImageInputType.IMAGE_FEATURES:
  117. if not isinstance(data, torch.Tensor):
  118. raise TypeError("Image feature vector should be a tensor, "
  119. f"but received type: {type(data)}")
  120. return LlavaImageFeatureInputs(
  121. type="image_features",
  122. data=self._validate_image_data(data),
  123. )
  124. return None
  125. def _select_image_features(self, image_features: torch.Tensor, *,
  126. strategy: str) -> torch.Tensor:
  127. # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
  128. if strategy == "default":
  129. return image_features[:, 1:]
  130. elif strategy == "full":
  131. return image_features
  132. raise ValueError(f"Unexpected select feature strategy: {strategy}")
  133. def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
  134. pixel_values: torch.Tensor) -> torch.Tensor:
  135. # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
  136. image_outputs = vision_tower(pixel_values.to(vision_tower.device),
  137. output_hidden_states=True)
  138. image_features = image_outputs.hidden_states[
  139. self.config.vision_feature_layer]
  140. return self._select_image_features(
  141. image_features,
  142. strategy=self.config.vision_feature_select_strategy,
  143. )
  144. def _process_image_pixels(self,
  145. inputs: LlavaImagePixelInputs) -> torch.Tensor:
  146. assert self.vision_tower is not None
  147. pixel_values = inputs["data"]
  148. return self._image_pixels_to_features(self.vision_tower, pixel_values)
  149. def _process_image_input(self,
  150. image_input: LlavaImageInputs) -> torch.Tensor:
  151. if image_input["type"] == "pixel_values":
  152. assert self.vision_tower is not None
  153. image_features = self._process_image_pixels(image_input)
  154. else:
  155. image_features = image_input["data"]
  156. return self.multi_modal_projector(image_features)
  157. def forward(self,
  158. input_ids: torch.Tensor,
  159. positions: torch.Tensor,
  160. kv_caches: List[torch.Tensor],
  161. attn_metadata: AttentionMetadata,
  162. image_input: Optional[torch.Tensor] = None) -> SamplerOutput:
  163. """Run forward pass for Llava 1.5.
  164. One key thing to understand is the `input_ids` already accounts for the
  165. positions of the to-be-inserted image embeddings.
  166. Concretely, consider a text prompt:
  167. "<image>\nUSER: What's the content of the image?\nASSISTANT:".
  168. Tokenizer outputs:
  169. [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
  170. 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
  171. The to-be-inserted image has a size of 576 (24 * 24) along the context
  172. length dimension.
  173. `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
  174. 1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
  175. 9047, 13566, 29901].
  176. There will be 576 `32000` in the `input_ids`.
  177. (32000 is the token id for `<image>`.)
  178. This way, the `positions` and `attn_metadata` are consistent
  179. with the `input_ids`.
  180. The model takes two types of image inputs:
  181. PIXEL_VALUES and IMAGE_FEATURES.
  182. The following shows how each maps to huggingface implementation.
  183. PIXEL_VALUES:
  184. - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353
  185. IMAGE_FEATURES:
  186. - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430
  187. before going through the multi modal projector.
  188. Args:
  189. input_ids: Flattened (concatenated) input_ids corresponding to a
  190. batch.
  191. image_input: A batch of image inputs.
  192. For PIXEL_VALUES, expecting [1, 3, 336, 336].
  193. For IMAGE_FEATURES, expecting [1, 576, 1024].
  194. """
  195. parsed_image_input = self._parse_and_validate_image_input(image_input)
  196. if parsed_image_input is not None:
  197. vision_embeddings = self._process_image_input(parsed_image_input)
  198. inputs_embeds = self.language_model.get_input_embeddings(input_ids)
  199. inputs_embeds = _merge_vision_embeddings(
  200. input_ids, inputs_embeds, vision_embeddings,
  201. self.vision_language_config.image_token_id)
  202. input_ids = None
  203. else:
  204. inputs_embeds = None
  205. hidden_states = self.language_model(input_ids,
  206. positions,
  207. kv_caches,
  208. attn_metadata,
  209. inputs_embeds=inputs_embeds)
  210. return hidden_states
  211. def compute_logits(self, hidden_states: torch.Tensor,
  212. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  213. logits = self.logits_processor(self.lm_head.weight, hidden_states,
  214. sampling_metadata)
  215. return logits
  216. def sample(
  217. self,
  218. logits: torch.Tensor,
  219. sampling_metadata: SamplingMetadata,
  220. ) -> Optional[SamplerOutput]:
  221. next_tokens = self.sampler(logits, sampling_metadata)
  222. return next_tokens
  223. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  224. # only doing this for language model part for now.
  225. stacked_params_mapping = [
  226. # (param_name, shard_name, shard_id)
  227. ("qkv_proj", "q_proj", "q"),
  228. ("qkv_proj", "k_proj", "k"),
  229. ("qkv_proj", "v_proj", "v"),
  230. ("gate_up_proj", "gate_proj", 0),
  231. ("gate_up_proj", "up_proj", 1),
  232. ]
  233. params_dict = dict(self.named_parameters())
  234. for name, loaded_weight in weights:
  235. if "rotary_emb.inv_freq" in name:
  236. continue
  237. for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
  238. if key_to_modify in name:
  239. name = name.replace(key_to_modify, new_key)
  240. use_default_weight_loading = False
  241. if "vision" in name:
  242. if self.vision_tower is not None:
  243. # We only do sharding for language model and
  244. # not vision model for now.
  245. use_default_weight_loading = True
  246. else:
  247. for (param_name, weight_name,
  248. shard_id) in stacked_params_mapping:
  249. if weight_name not in name:
  250. continue
  251. param = params_dict[name.replace(weight_name, param_name)]
  252. weight_loader = param.weight_loader
  253. weight_loader(param, loaded_weight, shard_id)
  254. break
  255. else:
  256. use_default_weight_loading = True
  257. if use_default_weight_loading:
  258. param = params_dict[name]
  259. weight_loader = getattr(param, "weight_loader",
  260. default_weight_loader)
  261. weight_loader(param, loaded_weight)