llava_next.py

from typing import (Iterable, List, Literal, Optional, Tuple, TypedDict,
                    Union)

import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaNextConfig
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.clip import CLIPVisionModel
from aphrodite.modeling.models.llama import LlamaModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
from aphrodite.quantization.base_config import QuantizationConfig

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                   get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsVision
from .llava import LlavaMultiModalProjector
from .utils import merge_vision_embeddings

_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}

# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448


class LlavaNextImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]

    data: BatchedTensors
    """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.
    """

    image_sizes: NotRequired[torch.Tensor]
    """
    Shape: `(batch_size, 2)`

    This should be in `(height, width)` format.
    """


LlavaNextImageInputs = LlavaNextImagePixelInputs
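
# Illustrative example (hypothetical values, not produced by this module): for
# a single 448x448 image tiled into a 2x2 anyres grid, the input mapper would
# yield roughly
#     LlavaNextImagePixelInputs(
#         type="pixel_values",
#         data=torch.empty(1, 5, 3, 336, 336),     # base view + 4 tiles
#         image_sizes=torch.tensor([[448, 448]]),  # (height, width)
#     )
# where the exact patch count depends on the configured grid pinpoints.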


# Taken from: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91
# NOTE: new_height and new_width are further incremented to properly invert the
# floordiv operation: https://github.com/huggingface/transformers/blob/v4.42.2/src/transformers/models/llava_next/modeling_llava_next.py#L133
def _get_llava_next_num_unpadded_features(
    height: int,
    width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    # This runs at input-processing time, so keep the arithmetic in plain
    # Python ints rather than materializing device tensors.
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    aspect_ratio: float = width / height
    current_aspect_ratio: float = current_width / current_height

    if aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / width
        new_height = int(height * scale_factor)
        padding = (current_height - new_height) // 2
        current_height -= padding * 2
    else:
        scale_factor = current_height / height
        new_width = int(width * scale_factor)
        padding = (current_width - new_width) // 2
        current_width -= padding * 2

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)
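
# Worked example (illustrative): with the CLIP ViT-L/14-336 tower assumed by
# the constants above (npatches = 336 // 14 = 24), a 448x448 image split into
# a 2x2 grid gives current_height = current_width = 48, no padding is removed,
# and the function returns (48 * 48, 48) == (2304, 48).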


# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L111
def get_llava_next_image_feature_size(
    hf_config: LlavaNextConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_patches = get_clip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = num_patches * num_patches

        # Note: We follow the "wrong" width/height order
        # [ref: PR huggingface/transformers#31588]
        num_patch_width, num_patch_height = get_anyres_image_grid_shape(
            image_size=(input_height, input_width),
            grid_pinpoints=hf_config.image_grid_pinpoints,
            patch_size=vision_config.image_size,
        )

        (
            unpadded_feature_size,
            newline_feature_size,
        ) = _get_llava_next_num_unpadded_features(input_height, input_width,
                                                  num_patches,
                                                  num_patch_height,
                                                  num_patch_width)

        return unpadded_feature_size + newline_feature_size + base_feature_size

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def get_max_llava_next_image_tokens(ctx: InputContext):
    return get_llava_next_image_feature_size(
        ctx.get_hf_config(LlavaNextConfig),
        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )
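
# Continuing the example above: assuming the default 336px CLIP tower and the
# default grid pinpoints (which select a 2x2 grid for a 448x448 image), the
# maximum image token count works out to
# 576 (base) + 2304 (unpadded) + 48 (newline) = 2928.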


def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    image_feature_size = get_max_llava_next_image_tokens(ctx)

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_clip(
            vision_config,
            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        )

        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        width, height = image_data.size

        image_feature_size = get_llava_next_image_feature_size(
            hf_config,
            input_height=height,
            input_width=width,
        )
    elif isinstance(image_data, torch.Tensor):
        raise NotImplementedError("Embeddings input is not supported yet")
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):

    def __init__(self,
                 config: LlavaNextConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # Initialize the vision tower only up to the required feature layer
        vision_feature_layer = config.vision_feature_layer
        if vision_feature_layer < 0:
            num_hidden_layers = config.vision_config.num_hidden_layers \
                + vision_feature_layer + 1
        else:
            num_hidden_layers = vision_feature_layer + 1

        # TODO: Optionally initialize this for supporting embeddings.
        self.vision_tower = CLIPVisionModel(
            config.vision_config, num_hidden_layers_override=num_hidden_layers)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.quant_config = quant_config
        self.language_model = LlamaModel(config.text_config, cache_config,
                                         quant_config)
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size,
            quant_config=quant_config)
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.text_config.vocab_size,
                                                logit_scale)
        self.sampler = Sampler()

        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        if list(data.shape[1:]) != [2]:
            raise ValueError(
                f"The expected image sizes shape is batch dimension plus "
                f"{[2]}. You supplied {data.shape}.")

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values in each batch element "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        if not isinstance(image_sizes, torch.Tensor):
            raise ValueError("Incorrect type of image sizes. "
                             f"Got type: {type(image_sizes)}")

        return LlavaNextImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
            image_sizes=self._validate_image_sizes(image_sizes),
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # image_aspect_ratio == "anyres"
                # Note: We follow the "wrong" width/height order
                # [ref: PR huggingface/transformers#31588]
                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                    image_size,
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                other_patch_embeds = other_patch_embeds \
                    .view(num_patch_height, num_patch_width, height, width, -1)

                if "unpad" in strategy:
                    other_patch_embeds = other_patch_embeds \
                        .permute(4, 0, 2, 1, 3).contiguous() \
                        .flatten(1, 2).flatten(2, 3)
                    other_patch_embeds = unpad_image(other_patch_embeds,
                                                     image_size)
                    other_patch_embeds = torch.cat((
                        other_patch_embeds,
                        self.image_newline[:, None, None] \
                            .expand(*other_patch_embeds.shape[:-1], 1) \
                            .to(other_patch_embeds.device),
                    ), dim=-1)
                    other_patch_embeds = other_patch_embeds \
                        .flatten(1, 2).transpose(0, 1)
                else:
                    other_patch_embeds = other_patch_embeds \
                        .permute(0, 2, 1, 3, 4).contiguous() \
                        .flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                if "unpad" in strategy:
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None] \
                             .to(base_patch_embeds.device)
                        ), dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")
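
    # Illustrative shape walk-through (assuming a 336px CLIP tower, so
    # height == width == 24): for a 448x448 image tiled into a 2x2 grid,
    # `patch_embeddings` has shape (5, 576, hidden_size) (base view + 4 tiles)
    # and the "spatial_unpad" strategy yields 576 + 2304 + 48 = 2928 rows.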

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> BatchedTensors:
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return [
            self.multi_modal_projector(image_features) for image_features in
            torch.split(stacked_image_features, num_patches_per_batch)
        ]

    def _process_image_input(
            self, image_input: LlavaNextImageInputs) -> BatchedTensors:
        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            batch_size = len(image_input["data"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
                                           for _ in range(batch_size)])

        return [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features_batch,
                                               strategy="spatial_unpad")
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for LLaVA-NeXT.

        One key thing to understand is that `input_ids` already accounts for
        the positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers to
        the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

        Tokenizer outputs:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
        9047, 13566, 29901]`.

        To reserve space in the KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
        319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens inputted to the
        language model depends on the original size of the input image.
        Including the original image token in the input, the required number
        of image tokens is given by :func:`get_llava_next_image_feature_size`.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each grid patch for each input image.
            image_sizes: The original `(height, width)` for each input image.

        See also:
            :class:`LlavaNextImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            None,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # post_layernorm is not needed in CLIPVisionModel
            if "vision_model.post_layernorm" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    use_default_weight_loading = True
            if use_default_weight_loading and name in params_dict:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
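

# Illustrative checkpoint-key handling (a sketch of the rules above, not an
# exhaustive list):
#     "language_model.model.layers.0.self_attn.q_proj.weight"
#         -> renamed to "language_model.layers.0.self_attn.q_proj.weight"
#         -> loaded into the stacked "...self_attn.qkv_proj" param, shard "q"
#     "language_model.lm_head.weight"
#         -> renamed to "lm_head.weight" and loaded with the default loader
#     "vision_tower.vision_model.post_layernorm.weight"
#         -> skipped (post_layernorm is not used by CLIPVisionModel here)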