import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.common.utils import is_list_of
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.quantization.base_config import QuantizationConfig

from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_clip_image_feature_size,
                   get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal
from .llava import LlavaMultiModalProjector
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                     get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (filter_weights, flatten_bn,
                    init_aphrodite_registered_model,
                    merge_multimodal_embeddings)

_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}

# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
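
# Rough reasoning behind the 448x448 choice (assuming the common llava-v1.6
# setup: 336px CLIP tiles and grid pinpoints that include 672x672): a 448x448
# input is matched by the anyres logic to the 672x672 pinpoint, i.e. a 2x2
# grid of 336x336 tiles, which is the largest patch grid and therefore the
# largest possible image feature size.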


class LlavaNextImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    image_sizes: NotRequired[torch.Tensor]
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(height, width)` format.
    """


class LlavaNextImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """


LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
                             LlavaNextImageEmbeddingInputs]


# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
def _get_llava_next_num_unpadded_features(
    original_height: int,
    original_width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        current_height -= 2 * padding
    else:
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        current_width -= 2 * padding

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)
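
# Worked example for the function above (illustrative numbers): for an
# original image of height=300, width=500 processed as a 2x2 grid of tiles
# with npatches=24 per side, current_height = current_width = 48. Since
# 500/300 > 48/48, the height is unpadded: scale_factor = 48/500,
# new_height = int(300 * 48 / 500) = 28, padding = (48 - 28) // 2 = 10, so
# current_height becomes 28. The function returns
# (28 * 48, 28) = (1344, 28) unpadded and newline features.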


# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
def get_llava_next_image_feature_size(
    hf_config: LlavaNextConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_patches = get_clip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = get_clip_image_feature_size(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_patches = get_siglip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = get_siglip_image_feature_size(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        base_feature_size -= 1
    elif strategy == "full":
        pass
    else:
        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
        image_size=(input_height, input_width),
        grid_pinpoints=hf_config.image_grid_pinpoints,
        patch_size=vision_config.image_size,
    )

    (
        unpadded_feature_size,
        newline_feature_size,
    ) = _get_llava_next_num_unpadded_features(input_height, input_width,
                                              num_patches, num_patch_height,
                                              num_patch_width)

    return unpadded_feature_size + newline_feature_size + base_feature_size
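
# Worked example (assuming the standard llava-v1.6 vision tower, i.e. CLIP
# ViT-L/14 at 336px with the "default" feature-select strategy and grid
# pinpoints that include 672x672): num_patches = 336 // 14 = 24 and
# base_feature_size = 577 - 1 = 576. A 448x448 input maps to a 2x2 anyres
# grid, so the unpadded features are 48 * 48 = 2304 plus 48 newline features,
# giving 2304 + 48 + 576 = 2928 tokens, which is what
# get_max_llava_next_image_tokens reports for that configuration.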


def get_max_llava_next_image_tokens(ctx: InputContext):
    return get_llava_next_image_feature_size(
        ctx.get_hf_config(LlavaNextConfig),
        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )


def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
                              mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    image_feature_size = get_max_llava_next_image_tokens(ctx)

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_clip(
            vision_config,
            num_images,
            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        )

        return seq_data, mm_data
    elif isinstance(vision_config, SiglipVisionConfig):
        seq_data = dummy_seq_data_for_siglip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_siglip(
            vision_config,
            num_images,
            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        )

        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
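
# Note: the dummy data above is what the engine uses when profiling peak
# memory before serving; it pairs a worst-case number of image placeholder
# tokens with `num_images` maximum-size (448x448) images, so the profile
# reflects the largest possible image feature count per request.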


def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        width, height = image_data.size

        image_feature_size = get_llava_next_image_feature_size(
            hf_config,
            input_height=height,
            input_width=width,
        )
    elif is_list_of(image_data, Image.Image):
        image_feature_size = [
            get_llava_next_image_feature_size(hf_config,
                                              input_height=img.height,
                                              input_width=img.width)
            for img in image_data
        ]
    elif isinstance(image_data, torch.Tensor):
        num_images, image_feature_size, hidden_size = image_data.shape
    elif is_list_of(image_data, torch.Tensor):
        image_feature_size = [item.shape[1] for item in image_data]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return input_processor_for_siglip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
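
# Note: the delegated CLIP/SigLIP input processors expand the single image
# placeholder token in the prompt into `image_feature_size` copies of
# `image_token_index`, so the scheduler reserves the right amount of KV-cache
# space before the real image embeddings exist (see the `forward()` docstring
# below for a concrete token-level example).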


def _init_vision_tower(hf_config: LlavaNextConfig):
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer
    vision_feature_layer = hf_config.vision_feature_layer
    if vision_feature_layer < 0:
        num_hidden_layers = hf_config.vision_config.num_hidden_layers \
            + vision_feature_layer + 1
    else:
        num_hidden_layers = vision_feature_layer + 1

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
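
# Illustrative arithmetic for the layer-count logic above (assuming a CLIP
# ViT-L/14-336 tower with 24 hidden layers): the common
# `vision_feature_layer = -2` gives `num_hidden_layers = 24 + (-2) + 1 = 23`,
# so the tower is built without its final transformer layer and its last
# hidden state is exactly the penultimate layer of the full model.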


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: LlavaNextConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initialize this for supporting embeddings.
        self.vision_tower = _init_vision_tower(config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_aphrodite_registered_model(
            config.text_config, cache_config, quant_config)

        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        expected_dims = (2, )

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    f"The expected shape of image sizes per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:

        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return LlavaNextImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(flatten_bn(pixel_values)),
                image_sizes=self._validate_image_sizes(
                    flatten_bn(image_sizes, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # Move to CPU to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

                # image_aspect_ratio == "anyres"
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                num_patches = num_patch_height * num_patch_width

                # Image patches might be padded for batch processing
                other_patch_embeds = other_patch_embeds[:num_patches] \
                    .view(num_patch_height, num_patch_width, height, width, -1)

                if "unpad" in strategy:
                    other_patch_embeds = other_patch_embeds \
                        .permute(4, 0, 2, 1, 3).contiguous() \
                        .flatten(1, 2).flatten(2, 3)
                    other_patch_embeds = unpad_image(other_patch_embeds,
                                                     (orig_height, orig_width))
                    other_patch_embeds = torch.cat((
                        other_patch_embeds,
                        self.image_newline[:, None, None] \
                            .expand(*other_patch_embeds.shape[:-1], 1) \
                            .to(other_patch_embeds.device),
                    ), dim=-1)
                    other_patch_embeds = other_patch_embeds \
                        .flatten(1, 2).transpose(0, 1)
                else:
                    other_patch_embeds = other_patch_embeds \
                        .permute(0, 2, 1, 3, 4).contiguous() \
                        .flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                if "unpad" in strategy:
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None] \
                             .to(base_patch_embeds.device)
                         ), dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")
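
    # Shape walkthrough for the "spatial_unpad" path of
    # _merge_image_patch_embeddings (illustrative, using a 2x2 anyres grid,
    # 24x24 patches per tile, and hidden size D): other_patch_embeds starts as
    # (4, 576, D), is viewed as (2, 2, 24, 24, D), permuted and flattened to
    # (D, 48, 48), cropped by unpad_image to (D, H', W'), extended with the
    # learned image_newline column to (D, H', W' + 1), and finally flattened
    # and transposed to (H' * (W' + 1), D) before being concatenated with the
    # (576, D) base patch embeddings. The resulting row count matches
    # get_llava_next_image_feature_size for the same image size.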

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return [
            self.multi_modal_projector(image_features) for image_features in
            torch.split(stacked_image_features, num_patches_per_batch)
        ]

    def _process_image_input(
        self,
        image_input: LlavaNextImageInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            batch_size = len(image_input["data"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
                                           for _ in range(batch_size)])

        return [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features_batch,
                                               strategy="spatial_unpad")
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]
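
    # Note on _process_image_input: it should yield one tensor of shape
    # (image_feature_size, hidden_size) per input image, whose first dimension
    # matches the number of placeholder image tokens that
    # input_processor_for_llava_next reserved for that image, so
    # merge_multimodal_embeddings can scatter it one-to-one in forward().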

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run the forward pass for LLaVA-NeXT.

        One key thing to understand is that `input_ids` already accounts for
        the positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers to
        the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

        Tokenizer outputs:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
        9047, 13566, 29901]`.

        To reserve space in the KV cache, placeholder tokens have to be
        inserted before the prompt is passed to the model, so the input
        processor prepends additional image tokens (denoted as `32000`),
        resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
        319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens passed to the language
        model depends on the original size of the input image. Including the
        original image token in the input, the required number of image tokens
        is given by :func:`get_llava_next_image_feature_size`.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each grid patch for each input image.
            image_sizes: The original `(height, width)` for each input image.

        See also:
            :class:`LlavaNextImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # prepare weight iterators for components
        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
            weights, 4)

        # load vision encoder
        vit_weights = filter_weights(vit_weights, "vision_tower")
        self.vision_tower.load_weights(vit_weights)

        # load mlp projector
        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in mlp_weights:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load newline
        newline_weights = filter_weights(newline_weights, "image_newline")
        for name, loaded_weight in newline_weights:
            assert name == ""
            param = self.image_newline
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)
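

# ---------------------------------------------------------------------------
# Usage sketch (not part of the model definition): assuming Aphrodite keeps
# vLLM's `LLM`/`SamplingParams` entry points and multi-modal prompt format,
# and that "llava-hf/llava-v1.6-vicuna-7b-hf" is available, serving this model
# would look roughly like the following; the exact prompt template and API
# surface may differ between releases.
#
#     from aphrodite import LLM, SamplingParams
#     from PIL import Image
#
#     llm = LLM(model="llava-hf/llava-v1.6-vicuna-7b-hf")
#     prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
#     image = Image.open("example.jpg")
#     outputs = llm.generate(
#         {"prompt": prompt, "multi_modal_data": {"image": image}},
#         SamplingParams(max_tokens=64),
#     )
#     print(outputs[0].outputs[0].text)
# ---------------------------------------------------------------------------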