llava_next.py

from typing import Iterable, List, Literal, Optional, Tuple, TypedDict

import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaNextConfig
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, VisionLanguageConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.clip import CLIPVisionModel
from aphrodite.modeling.models.llama import LlamaModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
from aphrodite.quantization.base_config import QuantizationConfig

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                   get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsVision
from .llava import LlavaMultiModalProjector
from .utils import merge_vision_embeddings
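
# Weight names in HF LLaVA-NeXT checkpoints carry a "language_model." prefix;
# remap them to the attribute names used by this implementation.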
_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}


class LlavaNextImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: BatchedTensors
    """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different for each batch.
    """

    image_sizes: NotRequired[torch.Tensor]
    """
    Shape: `(batch_size, 2)`

    This should be in `(height, width)` format.
    """


LlavaNextImageInputs = LlavaNextImagePixelInputs


# Taken from: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91
# NOTE: new_height and new_width are further incremented to properly invert the
# floordiv operation: https://github.com/huggingface/transformers/blob/v4.42.2/src/transformers/models/llava_next/modeling_llava_next.py#L133
def _get_llava_next_num_unpadded_features(
    height: int,
    width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    aspect_ratio: float = width / height
    current_aspect_ratio: float = current_width / current_height
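    # Decide which dimension was padded when the image was resized into the
    # anyres patch grid: a relatively wider image was padded along its height
    # (and vice versa), so that dimension is scaled back down to its unpadded
    # extent before counting features.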
    if aspect_ratio > current_aspect_ratio:
        new_height = (height * current_width) // width
        if new_height % 2 == 1:
            new_height += 1
        current_height = new_height
    else:
        new_width = (width * current_height) // height
        if new_width % 2 == 1:
            new_width += 1
        current_width = new_width

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)


# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L111
def get_llava_next_image_feature_size(
    hf_config: LlavaNextConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_patches = get_clip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
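        # Features contributed by the base (globally resized) tile.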
        base_feature_size = num_patches * num_patches

        # Note: We follow the "wrong" width/height order
        # [ref: PR huggingface/transformers#31588]
        num_patch_width, num_patch_height = get_anyres_image_grid_shape(
            image_size=(input_height, input_width),
            grid_pinpoints=hf_config.image_grid_pinpoints,
            patch_size=vision_config.image_size,
        )

        (
            unpadded_feature_size,
            newline_feature_size,
        ) = _get_llava_next_num_unpadded_features(input_height, input_width,
                                                  num_patches,
                                                  num_patch_height,
                                                  num_patch_width)

        return unpadded_feature_size + newline_feature_size + base_feature_size

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    # Results in the max possible feature size (2x2 grid of 336x336px tiles)
    dummy_height = dummy_width = 448
    image_feature_size = get_llava_next_image_feature_size(
        hf_config,
        input_height=dummy_height,
        input_width=dummy_width,
    )

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_clip(
            vision_config,
            image_width_override=dummy_width,
            image_height_override=dummy_height,
        )

        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        width, height = image_data.size

        image_feature_size = get_llava_next_image_feature_size(
            hf_config,
            input_height=height,
            input_width=width,
        )
    elif isinstance(image_data, torch.Tensor):
        raise NotImplementedError("Embeddings input is not supported yet")
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):

    def __init__(self,
                 config: LlavaNextConfig,
                 vlm_config: VisionLanguageConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.vlm_config = vlm_config

        # TODO: Optionally initialize this to support embedding inputs.
        self.vision_tower = CLIPVisionModel(config=config.vision_config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.quant_config = quant_config
        self.language_model = LlamaModel(config.text_config, cache_config,
                                         quant_config)
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size,
            quant_config=quant_config)
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        if list(data.shape[1:]) != [2]:
            raise ValueError(
                f"The expected image sizes shape is batch dimension plus "
                f"{[2]}. You supplied {data.shape}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        if not isinstance(image_sizes, torch.Tensor):
            raise ValueError("Incorrect type of image sizes. "
                             f"Got type: {type(image_sizes)}")

        return LlavaNextImagePixelInputs(
            type="pixel_values",
            data=pixel_values,
            image_sizes=self._validate_image_sizes(image_sizes),
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
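        # "default" drops the CLS token (position 0) and keeps only the patch
        # tokens; "full" keeps every output token.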
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values,
                                      self.config.vision_feature_layer)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)
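
        # The "spatial" strategies reassemble the per-tile patch grids into a
        # single 2D layout matching the original image; "spatial_unpad" also
        # strips the resize padding and appends image_newline separators.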
        if strategy.startswith("spatial"):
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # image_aspect_ratio == "anyres"
                # Note: We follow the "wrong" width/height order
                # [ref: PR huggingface/transformers#31588]
                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                    image_size,
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                other_patch_embeds = other_patch_embeds \
                    .view(num_patch_height, num_patch_width, height, width, -1)
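                # The grid is now indexed as
                # (tile_row, tile_col, patch_row, patch_col, hidden_dim).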
  257. if "unpad" in strategy:
  258. other_patch_embeds = other_patch_embeds \
  259. .permute(4, 0, 2, 1, 3).contiguous() \
  260. .flatten(1, 2).flatten(2, 3)
  261. other_patch_embeds = unpad_image(other_patch_embeds,
  262. image_size)
  263. other_patch_embeds = torch.cat((
  264. other_patch_embeds,
  265. self.image_newline[:, None, None] \
  266. .expand(*other_patch_embeds.shape[:-1], 1) \
  267. .to(other_patch_embeds.device),
  268. ), dim=-1)
  269. other_patch_embeds = other_patch_embeds \
  270. .flatten(1, 2).transpose(0, 1)
  271. else:
  272. other_patch_embeds = other_patch_embeds \
  273. .permute(0, 2, 1, 3, 4).contiguous() \
  274. .flatten(0, 3)
  275. merged_patch_embeddings = torch.cat(
  276. (base_patch_embeds, other_patch_embeds), dim=0)
  277. else:
  278. if "unpad" in strategy:
  279. merged_patch_embeddings = torch.cat(
  280. (base_patch_embeds,
  281. self.image_newline[None] \
  282. .to(base_patch_embeds.device)
  283. ), dim=0)
  284. else:
  285. merged_patch_embeddings = base_patch_embeds
  286. return merged_patch_embeddings
  287. raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> BatchedTensors:
        assert self.vision_tower is not None

        pixel_values = inputs["data"]
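
        # When every image in the batch produces the same number of patches,
        # the inputs arrive as a single stacked tensor; otherwise they arrive
        # as a list of per-image tensors with differing patch counts.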
        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return [
            self.multi_modal_projector(image_features) for image_features in
            torch.split(stacked_image_features, num_patches_per_batch)
        ]

    def _process_image_input(
            self, image_input: LlavaNextImageInputs) -> BatchedTensors:
        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
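        # If no sizes were provided, assume each image matches the vision
        # encoder's default square resolution.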
        if image_sizes is None:
            batch_size = len(image_input["data"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
                                           for _ in range(batch_size)])

        return [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features_batch,
                                               strategy="spatial_unpad")
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for LLaVA-NeXT.

        One key thing to understand is that `input_ids` already accounts for
        the positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        "<image>\nUSER: What's the content of the image?\nASSISTANT:".

        Tokenizer outputs:
        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].

        The to-be-inserted image has a size of 576 (24 * 24) along the context
        length dimension.
        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
        9047, 13566, 29901].
        There will be 576 `32000` in the `input_ids`.
        (32000 is the token id for `<image>`.)

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each grid patch for each input image.
                Expects a batch with shape `[1, num_patches, 3, h, w]`.
            image_sizes: The original `(height, width)` for each input image.
                Expects a batch with shape `[1, 2]`.

        See also:
            Each input maps to the corresponding field in the HuggingFace
            implementation:

            - `pixel_values`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava_next/modeling_llava_next.py#L690
            - `image_sizes`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava_next/modeling_llava_next.py#L691
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
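            # Scatter the projected vision embeddings into the token embedding
            # sequence at the <image> placeholder positions, then feed
            # embeddings (rather than token ids) to the language model.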
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.vlm_config.image_token_id)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            None,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Only doing this for the language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
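
        # Fused layers (qkv_proj, gate_up_proj) are stored as separate
        # projections in the checkpoint; the mapping above routes each shard
        # into the correct slice of the fused parameter.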
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # post_layernorm is not needed in CLIPVisionModel
            if "vision_model.post_layernorm" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    use_default_weight_loading = True
            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)