# llava_next.py

from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict

import torch
import torch.nn as nn
from loguru import logger
from PIL import Image
from transformers import CLIPVisionConfig, LlavaNextConfig
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, VisionLanguageConfig
from aphrodite.common.sequence import SamplerOutput
from aphrodite.inputs import INPUT_REGISTRY, InputContext
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.clip import CLIPVisionModel
from aphrodite.modeling.models.llama import LlamaModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.quantization.base_config import QuantizationConfig

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                   get_clip_patch_grid_length)
from .interfaces import SupportsVision
from .llava import LlavaMultiModalProjector, merge_vision_embeddings

_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}


class LlavaNextImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""

    image_sizes: NotRequired[torch.Tensor]
    """Shape: (batch_size, 2)"""


LlavaNextImageInputs = LlavaNextImagePixelInputs


def _get_llava_next_num_unpadded_features(
    height: int,
    width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    # Taken from: https://github.com/huggingface/text-generation-inference/blob/799a193b109662743bed1b18a09af1fdcd508c8b/server/text_generation_server/models/vlm_causal_lm.py#L111
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    aspect_ratio: float = width / height
    current_aspect_ratio: float = current_width / current_height
    if aspect_ratio > current_aspect_ratio:
        new_height = (height * current_width) // width
        current_height = new_height
    else:
        new_width = (width * current_height) // height
        current_width = new_width

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)
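

# Worked example for _get_llava_next_num_unpadded_features (illustrative
# numbers, assuming a 2x2 anyres grid rather than values from a real request):
#
#     >>> _get_llava_next_num_unpadded_features(
#     ...     height=1080, width=1920, npatches=24,
#     ...     num_patch_height=2, num_patch_width=2)
#     (1296, 27)
#
# The image aspect ratio (1920 / 1080 ~= 1.78) exceeds the padded grid's
# (48 / 48 = 1.0), so the 48-row grid is cropped to (1080 * 48) // 1920 = 27
# rows, giving 27 * 48 = 1296 unpadded features plus 27 newline features.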


def _get_llava_next_image_feature_size(
    hf_config: LlavaNextConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_patches = get_clip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = num_patches * num_patches

        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
            image_size=(input_height, input_width),
            grid_pinpoints=hf_config.image_grid_pinpoints,
            patch_size=vision_config.image_size,
        )

        (
            unpadded_feature_size,
            newline_feature_size,
        ) = _get_llava_next_num_unpadded_features(input_height, input_width,
                                                  num_patches,
                                                  num_patch_height,
                                                  num_patch_width)

        return unpadded_feature_size + newline_feature_size + base_feature_size

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
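

# Worked example for _get_llava_next_image_feature_size (illustrative, assuming
# the stock llava-v1.6 CLIP settings and default image_grid_pinpoints:
# image_size=336 and patch_size=14, so the patch grid is 24x24 and the base
# feature size is 576). For a 672x672 input, get_anyres_image_grid_shape picks
# a 2x2 grid (672 // 336 per side), the unpadded features come out to
# 48 * 48 = 2304 with 48 newline features, and the function returns
# 2304 + 48 + 576 = 2928 image tokens.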


def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
    multimodal_config = ctx.get_multimodal_config()

    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    # TODO: change the logic for dummy data to support dynamic shape
    _, _, dummy_height, dummy_width = multimodal_config.image_input_shape
    image_feature_size = _get_llava_next_image_feature_size(
        hf_config, input_height=dummy_height, input_width=dummy_width)

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_clip(
            vision_config,
            image_width_override=dummy_width,
            image_height_override=dummy_height,
        )

        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def _pixel_mapper(ctx: InputContext, image: object) -> Dict[str, torch.Tensor]:
    if isinstance(image, Image.Image):
        # Temporary patch before dynamic number of image tokens is supported
        _, _, h, w = ctx.get_multimodal_config().image_input_shape
        if (w, h) != (image.width, image.height):
            logger.warning(
                "Dynamic image shape is currently not supported. "
                "Resizing input image to (%d, %d).", w, h)

            image = image.resize((w, h))

        return MULTIMODAL_REGISTRY._get_plugin("image") \
            ._default_input_mapper(ctx, image)

    raise TypeError(f"Invalid type for 'image': {type(image)}")


@MULTIMODAL_REGISTRY.register_image_input_mapper(_pixel_mapper)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):

    def __init__(self,
                 config: LlavaNextConfig,
                 vlm_config: VisionLanguageConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.vlm_config = vlm_config

        self.vision_tower = CLIPVisionModel(config=config.vision_config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.quant_config = quant_config
        self.language_model = LlamaModel(config.text_config, cache_config,
                                         quant_config)
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size)
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))

    def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor:
        _, num_channels, _, _ = self.vlm_config.image_input_shape

        # Note that this is different from that of vLLM vision_language_config
        # since the image is resized by the HuggingFace preprocessor
        height = width = self.config.vision_config.image_size

        if list(data.shape[2:]) != [num_channels, height, width]:
            raise ValueError(
                f"The expected image tensor shape is batch dimension plus "
                f"num_patches plus {[num_channels, height, width]}. "
                f"You supplied {data.shape}. "
                f"If you are using vLLM's entrypoint, make sure your "
                f"supplied image input is consistent with "
                f"image_input_shape in engine args.")

        return data

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        if list(data.shape[1:]) != [2]:
            raise ValueError(
                f"The expected image sizes shape is batch dimension plus "
                f"{[2]}. You supplied {data.shape}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)

        if pixel_values is None or image_sizes is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        if not isinstance(image_sizes, torch.Tensor):
            raise ValueError("Incorrect type of image sizes. "
                             f"Got type: {type(image_sizes)}")

        return LlavaNextImagePixelInputs(
            type="pixel_values",
            data=self._validate_image_pixels(pixel_values),
            image_sizes=self._validate_image_sizes(image_sizes),
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values,
                                      self.config.vision_feature_layer)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            orig_width, orig_height = image_size
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # image_aspect_ratio == "anyres"
                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                    (orig_width, orig_height),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                other_patch_embeds = other_patch_embeds \
                    .view(num_patch_width, num_patch_height, height, width, -1)

                if "unpad" in strategy:
                    other_patch_embeds = other_patch_embeds \
                        .permute(4, 0, 2, 1, 3).contiguous() \
                        .flatten(1, 2).flatten(2, 3)
                    other_patch_embeds = unpad_image(other_patch_embeds,
                                                     image_size)
                    other_patch_embeds = torch.cat((
                        other_patch_embeds,
                        self.image_newline[:, None, None] \
                            .expand(*other_patch_embeds.shape[:-1], 1) \
                            .to(other_patch_embeds.device),
                    ), dim=-1)
                    other_patch_embeds = other_patch_embeds \
                        .flatten(1, 2).transpose(0, 1)
                else:
                    other_patch_embeds = other_patch_embeds \
                        .permute(0, 2, 1, 3, 4).contiguous() \
                        .flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                if "unpad" in strategy:
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None] \
                             .to(base_patch_embeds.device)
                         ), dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")
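
    # Shape sketch for the "spatial_unpad" path above (illustrative, assuming
    # the stock 336/14 CLIP tower, so height = width = 24, and a 2x2 anyres
    # grid, i.e. patch_embeddings has shape (5, 576, hidden_size)):
    #   - base_patch_embeds: (576, hidden_size)
    #   - other_patch_embeds after .view(): (2, 2, 24, 24, hidden_size)
    #   - after .permute(4, 0, 2, 1, 3) and the two .flatten() calls:
    #     (hidden_size, 48, 48)
    #   - unpad_image() crops the rows/columns that only existed as padding for
    #     the original aspect ratio, image_newline appends one extra column,
    #     and the final .flatten(1, 2).transpose(0, 1) yields
    #     (unpadded_features + newline_features, hidden_size) rows that are
    #     concatenated after the 576 base rows.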

    def _process_image_pixels(
            self, inputs: LlavaNextImagePixelInputs) -> torch.Tensor:
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        b, num_patches, c, h, w = pixel_values.shape
        stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)

        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return stacked_image_features.view(b, num_patches,
                                           *stacked_image_features.shape[-2:])

    def _process_image_input(
            self, image_input: LlavaNextImageInputs) -> torch.Tensor:
        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        patch_embeddings = self.multi_modal_projector(image_features)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            batch_size = image_input["data"].shape[0]
            vision_config = self.config.vision_config
            default_width = default_height = vision_config.image_size
            image_sizes = torch.as_tensor([[default_width, default_height]
                                           for _ in range(batch_size)])

        merged_patch_embeddings = [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features,
                                               strategy="spatial_unpad")
            for i, patch_features in enumerate(patch_embeddings)
        ]

        return torch.stack(merged_patch_embeddings, dim=0)
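
    # End-to-end shape sketch for a single image (illustrative, same 336/14
    # CLIP tower and 2x2 anyres grid assumed as above):
    #   - inputs["data"]: (1, 5, 3, 336, 336), stacked to (5, 3, 336, 336)
    #   - vision tower + projector: (1, 5, 576, text_hidden_size)
    #   - _merge_image_patch_embeddings(..., strategy="spatial_unpad"):
    #     (image_feature_size, text_hidden_size) per image, stacked to
    #     (1, image_feature_size, text_hidden_size) and later scattered into
    #     the <image> token positions by merge_vision_embeddings() in forward().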

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for LLaVA-NeXT.

        One key thing to understand is that `input_ids` already accounts for
        the positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        "<image>\nUSER: What's the content of the image?\nASSISTANT:".

        Tokenizer outputs:
        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].

        The to-be-inserted image has a size of 576 (24 * 24) along the context
        length dimension.
        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
        9047, 13566, 29901].
        There will be 576 occurrences of `32000` in `input_ids`.
        (32000 is the token id for `<image>`.)

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each grid patch for each input image.
                Expects a batch with shape `[1, num_patches, 3, 336, 336]`.
            image_sizes: The original `(width, height)` for each input image.
                Expects a batch with shape `[1, 2]`.

        See also:
            Each input maps to the huggingface implementation, as follows:

            - `pixel_values`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava_next/modeling_llava_next.py#L690
            - `image_sizes`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava_next/modeling_llava_next.py#L691
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.vlm_config.image_token_id)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # post_layernorm is not needed in CLIPVisionModel
            if "vision_model.post_layernorm" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    use_default_weight_loading = True
            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)