import math
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import (CLIPVisionConfig, LlavaOnevisionConfig,
                          SiglipVisionConfig)
from transformers.models.llava_onevision.modeling_llava_onevision import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.common.utils import is_list_of
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.utils import (cached_get_tokenizer,
                                        repeat_and_pad_placeholder_tokens)
from aphrodite.quantization.base_config import QuantizationConfig

from .clip import (CLIPVisionModel, dummy_seq_data_for_clip,
                   dummy_video_for_clip, get_clip_image_feature_size,
                   get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
                     dummy_video_for_siglip, get_siglip_image_feature_size,
                     get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (flatten_bn, group_weights_with_prefix,
                    init_aphrodite_registered_model,
                    merge_multimodal_embeddings)

# Results in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448

# For profile run
_MAX_FRAMES_PER_VIDEO = 16
_MAX_NUM_VIDEOS = 1


class LlavaOnevisionVideoPixelInputs(TypedDict):
    type: Literal["pixel_values_videos"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size, num_frames, num_channels, height, width)`

    Note that `num_frames` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.

    Note that only one video input per batch is supported.
    """


class LlavaOnevisionImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    image_sizes: NotRequired[torch.Tensor]
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(height, width)` format.
    """


class LlavaOnevisionImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """


LlavaOnevisionImageInputs = Union[LlavaOnevisionImagePixelInputs,
                                  LlavaOnevisionImageEmbeddingInputs]

LlavaOnevisionMultiInputs = Union[LlavaOnevisionImageInputs,
                                  LlavaOnevisionVideoPixelInputs]


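# NOTE: In the helper below, `patches` is the side length of the per-tile
# patch grid (image_size // patch_size) and `scale_height` / `scale_width`
# give the anyres grid shape in tiles. The padded tile grid is cropped back
# to the original aspect ratio, one "newline" feature is counted per
# remaining row, and the result is downscaled whenever it exceeds roughly a
# nine-tile (`anyres_max_9`) budget.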
def _get_llava_onevision_image_unpadded_feature_size(height, width, patches,
                                                     scale_height,
                                                     scale_width):
    current_height = patches * scale_height
    current_width = patches * scale_width

    original_aspect_ratio = width / height
    current_aspect_ratio = current_width / current_height
    if original_aspect_ratio > current_aspect_ratio:
        new_height = int(height * (current_width / width))
        padding = (current_height - new_height) // 2
        current_height -= padding * 2
    else:
        new_width = int(width * (current_height / height))
        padding = (current_width - new_width) // 2
        current_width -= padding * 2

    unpadded_features = current_height * current_width
    newline_features = current_height

    ratio = math.sqrt(current_height * current_width / (9 * patches**2))
    if ratio > 1.1:
        unpadded_features = int(current_height // ratio) * int(
            current_width // ratio)
        newline_features = int(current_height // ratio)
    return (unpadded_features, newline_features)


def get_llava_onevision_image_feature_size(
    hf_config: LlavaOnevisionConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_patches = get_clip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = get_clip_image_feature_size(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_patches = get_siglip_patch_grid_length(
            image_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
        )
        base_feature_size = get_siglip_image_feature_size(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        base_feature_size -= 1
    elif strategy == "full":
        pass
    else:
        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
        image_size=(input_height, input_width),
        grid_pinpoints=hf_config.image_grid_pinpoints,
        patch_size=vision_config.image_size,
    )

    (
        unpadded_feature_size,
        newline_feature_size,
    ) = _get_llava_onevision_image_unpadded_feature_size(
        input_height,
        input_width,
        num_patches,
        num_patch_height,
        num_patch_width,
    )

    return unpadded_feature_size + newline_feature_size + base_feature_size


def get_max_llava_onevision_image_tokens(ctx: InputContext):
    return get_llava_onevision_image_feature_size(
        ctx.get_hf_config(LlavaOnevisionConfig),
        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )


def get_llava_onevision_video_frame_feature_size(
        hf_config: LlavaOnevisionConfig) -> int:
    # Support both CLIPVisionConfig and SiglipVisionConfig
    image_size = hf_config.vision_config.image_size
    patch_size = hf_config.vision_config.patch_size
    spatial_pool_stride = hf_config.spatial_pool_stride if hasattr(
        hf_config, "spatial_pool_stride") else 2

    height = width = image_size // patch_size
    return math.ceil(height / spatial_pool_stride) * math.ceil(
        width / spatial_pool_stride)


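# Worked example (assuming the default SigLIP tower with image_size=384,
# patch_size=14 and spatial_pool_stride=2): the per-frame grid is
# 384 // 14 = 27, which pools to ceil(27 / 2) = 14 per side, i.e.
# 14 * 14 = 196 tokens per frame. With the single newline token added
# below, a 16-frame video then occupies 16 * 196 + 1 = 3137 tokens.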
def get_llava_onevision_video_tokens(ctx: InputContext,
                                     num_frames: int) -> int:
    hf_config = ctx.get_hf_config(LlavaOnevisionConfig)

    # TODO: support configuring (not supported by HF right now)
    num_token_image_newline = 1
    tokens_per_frame = get_llava_onevision_video_frame_feature_size(hf_config)
    video_feature_size = num_frames * tokens_per_frame + num_token_image_newline

    return video_feature_size


def get_max_llava_onevision_video_tokens(ctx: InputContext) -> int:
    return get_llava_onevision_video_tokens(ctx, _MAX_FRAMES_PER_VIDEO)


def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
                                   mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
    vision_config = hf_config.vision_config

    # TODO: support multiple videos
    num_videos = mm_counts["video"]
    if num_videos > _MAX_NUM_VIDEOS:
        raise NotImplementedError(
            f"At most {_MAX_NUM_VIDEOS} video(s) are supported")

    # TODO: support configuring the number of frames
    num_frames = _MAX_FRAMES_PER_VIDEO
    video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            num_videos,
            image_token_id=hf_config.video_token_index,
            image_feature_size_override=video_feature_size,
        )

        mm_data = dummy_video_for_clip(vision_config, num_frames=num_frames)
        return seq_data, mm_data
    elif isinstance(vision_config, SiglipVisionConfig):
        seq_data = dummy_seq_data_for_siglip(
            vision_config,
            seq_len,
            num_videos,
            image_token_id=hf_config.video_token_index,
            image_feature_size_override=video_feature_size,
        )

        mm_data = dummy_video_for_siglip(vision_config, num_frames=num_frames)
        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


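# The input processors below expand the single image/video placeholder token
# in the prompt into `image_feature_size` / `video_feature_size` repeated
# placeholder tokens, so that the sequence length seen by the scheduler
# matches the number of multimodal embeddings merged in forward().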
def input_processor_when_multimodal_input_image(ctx: InputContext,
                                                llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        width, height = image_data.size

        image_feature_size = get_llava_onevision_image_feature_size(
            hf_config,
            input_height=height,
            input_width=width,
        )
    elif is_list_of(image_data, Image.Image):
        image_feature_size = [
            get_llava_onevision_image_feature_size(hf_config,
                                                   input_height=img.height,
                                                   input_width=img.width)
            for img in image_data
        ]
    elif isinstance(image_data, torch.Tensor):
        num_images, image_feature_size, hidden_size = image_data.shape
    elif is_list_of(image_data, torch.Tensor):
        image_feature_size = [item.shape[1] for item in image_data]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return input_processor_for_siglip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def input_processor_when_multimodal_input_video(ctx: InputContext,
                                                llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "video" not in multi_modal_data:
        return llm_inputs
    video_data = multi_modal_data["video"]

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaOnevisionConfig)

    if isinstance(video_data, np.ndarray):
        # Supports both CLIP and Siglip
        num_frames = video_data.shape[0]
        video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
        tokenizer = cached_get_tokenizer(model_config.tokenizer)

        new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
            tokenizer,
            llm_inputs.get("prompt"),
            llm_inputs["prompt_token_ids"],
            placeholder_token_id=hf_config.video_token_index,
            repeat_count=video_feature_size,
        )

        return LLMInputs(prompt_token_ids=new_token_ids,
                         prompt=new_prompt,
                         multi_modal_data=multi_modal_data)

    elif is_list_of(video_data, np.ndarray):
        raise NotImplementedError(
            "Processing multiple videos is not supported")

    msg = f"Unsupported video data type: {type(video_data)}"
    raise NotImplementedError(msg)


def input_processor_for_llava_onevision(ctx: InputContext,
                                        llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or ("video" not in multi_modal_data
                                    and "image" not in multi_modal_data):
        return llm_inputs
    if "image" in multi_modal_data:
        return input_processor_when_multimodal_input_image(ctx, llm_inputs)
    if "video" in multi_modal_data:
        return input_processor_when_multimodal_input_video(ctx, llm_inputs)

    msg = "Unsupported multimodal data type"
    raise NotImplementedError(msg)


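# NOTE: `_init_vision_tower` below builds the vision encoder only up to the
# layer selected by `vision_feature_layer`. For example (values illustrative),
# a 27-layer tower with vision_feature_layer = -2 keeps
# 27 + (-2) + 1 = 26 layers, i.e. everything up to the penultimate layer.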
def _init_vision_tower(hf_config: LlavaOnevisionConfig):
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer
    vision_feature_layer = hf_config.vision_feature_layer
    if vision_feature_layer < 0:
        num_hidden_layers = (hf_config.vision_config.num_hidden_layers +
                             vision_feature_layer + 1)
    else:
        num_hidden_layers = vision_feature_layer + 1

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


class LlavaOnevisionMultiModalProjector(nn.Module):

    def __init__(self, config: LlavaOnevisionConfig):
        super().__init__()

        self.linear_1 = nn.Linear(config.vision_config.hidden_size,
                                  config.text_config.hidden_size,
                                  bias=True)
        self.act = get_act_fn(config.projector_hidden_act)
        self.linear_2 = nn.Linear(config.text_config.hidden_size,
                                  config.text_config.hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


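# The registry decorators below wire this model into Aphrodite's multimodal
# pipeline: default input mappers for "image" and "video", the max-token
# functions defined above (used for profiling and scheduling), dummy data
# for the profile run, and the prompt processor that expands placeholder
# tokens.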
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_input_mapper("video")
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "image", get_max_llava_onevision_image_tokens)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "video", get_max_llava_onevision_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_onevision)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_onevision)
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: LlavaOnevisionConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # Initialize the vision tower only up to the required feature layer
        self.vision_tower = _init_vision_tower(config)
        self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
        self.language_model = init_aphrodite_registered_model(
            config.text_config, cache_config, quant_config)
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        expected_dims = (2, )

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    f"The expected shape of image sizes per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _validate_image_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:

        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaOnevisionImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return LlavaOnevisionImagePixelInputs(
                type="pixel_values",
                data=self._validate_image_pixel_values(
                    flatten_bn(pixel_values)),
                image_sizes=self._validate_image_sizes(
                    flatten_bn(image_sizes, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaOnevisionImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _validate_video_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:

        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[2:])

            if actual_dims != expected_dims:
                expected_expr = ("num_frames", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values in each video frame "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_video_input(
            self,
            **kwargs: object) -> Optional[LlavaOnevisionVideoPixelInputs]:
        """
        A valid video input should have the following shape:
        {
            "pixel_values_videos":
                List[b, Tensor(num_frames, num_channels, height, width)]
        }
        """
        pixel_values = kwargs.pop("pixel_values_videos", None)

        if pixel_values is None:
            return None

        if not (is_list_of(pixel_values,
                           torch.Tensor)  # different shape videos
                or isinstance(pixel_values,
                              torch.Tensor)):  # same shape videos
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return LlavaOnevisionVideoPixelInputs(
            type="pixel_values_videos",
            data=pixel_values,
        )

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        modalities = {}

        if "pixel_values" in kwargs:
            modalities["images"] = self._parse_and_validate_image_input(
                **kwargs)

        if "pixel_values_videos" in kwargs:
            modalities["videos"] = self._parse_and_validate_video_input(
                **kwargs)

        return modalities

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

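    # The "spatial_unpad" merge below follows the anyres layout: index 0 of
    # `patch_embeddings` is the base (whole-image) tile and the remaining
    # entries are high-resolution tiles on a (num_patch_height,
    # num_patch_width) grid. The tile grid is unpadded back to the original
    # aspect ratio, optionally downscaled to the `anyres_max_9` budget, and
    # an `image_newline` embedding is appended to each row before everything
    # is flattened and concatenated after the base tile.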
    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(
        self,
        image_size: torch.Tensor,
        patch_embeddings: torch.Tensor,
        *,
        image_newline=None,
        vision_aspect_ratio="anyres_max_9",
        strategy: str,
    ) -> torch.Tensor:
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = (self.config.vision_config.image_size //
                              self.config.vision_config.patch_size)

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # Move to CPU to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

                # image_aspect_ratio == "anyres"
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                num_patches = num_patch_height * num_patch_width

                # Image patches might be padded for batch processing
                other_patch_embeds = other_patch_embeds[:num_patches].view(
                    num_patch_height, num_patch_width, height, width, -1)

                if "unpad" in strategy:
                    other_patch_embeds = (
                        other_patch_embeds.permute(4, 0, 2, 1, 3)
                        .contiguous()
                        .flatten(1, 2)
                        .flatten(2, 3)
                    )
                    other_patch_embeds = unpad_image(
                        other_patch_embeds, (orig_height, orig_width))
                    max_num_patches = int(
                        vision_aspect_ratio.removeprefix("anyres_max_"))
                    channels, curr_height, curr_width = other_patch_embeds.shape
                    ratio = math.sqrt(curr_height * curr_width /
                                      (max_num_patches * height**2))
                    if ratio > 1.1:
                        other_patch_embeds = other_patch_embeds[None]
                        other_patch_embeds = nn.functional.interpolate(
                            other_patch_embeds,
                            [int(curr_height // ratio),
                             int(curr_width // ratio)],
                            mode="bilinear",
                        )[0]
                    if image_newline is not None:
                        other_patch_embeds = torch.cat(
                            (
                                other_patch_embeds,
                                image_newline[:, None, None].expand(
                                    *other_patch_embeds.shape[:-1], 1).to(
                                        other_patch_embeds.device),
                            ),
                            dim=-1,
                        )
                    other_patch_embeds = other_patch_embeds.flatten(
                        1, 2).transpose(0, 1)
                else:
                    other_patch_embeds = (
                        other_patch_embeds.permute(0, 2, 1, 3, 4)
                        .contiguous()
                        .flatten(0, 3)
                    )

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                if "unpad" in strategy:
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None].to(
                             base_patch_embeds.device)),
                        dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
        self,
        inputs: LlavaOnevisionImagePixelInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return [
            self.multi_modal_projector(image_features) for image_features in
            torch.split(stacked_image_features, num_patches_per_batch)
        ]

    def _process_image_input(
        self,
        image_input: LlavaOnevisionImageInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            batch_size = len(image_input["data"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
                                           for _ in range(batch_size)])

        return [
            self._merge_image_patch_embeddings(
                image_sizes[i],
                patch_features_batch,
                image_newline=self.image_newline,
                strategy="spatial_unpad",
            ) for i, patch_features_batch in enumerate(patch_embeddings)
        ]

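    # Shape flow for `_video_pixels_to_features` below (the pooled count of
    # 196 tokens per frame assumes the default 384px / patch-14 SigLIP tower):
    #   (b, 1, frames, 3, H, W) -> tower -> projector -> pooling
    #   -> (b * frames, 196, hidden) -> reshape -> (b, frames * 196, hidden)
    #   -> append one image_newline row -> (b, frames * 196 + 1, hidden)
    #   -> flatten -> (b * (frames * 196 + 1), hidden),
    # matching the count returned by get_llava_onevision_video_tokens().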
    def _video_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        b, num_videos, frames, c, h, w = pixel_values.shape
        assert num_videos == _MAX_NUM_VIDEOS
        pixel_values = pixel_values.reshape(b * num_videos * frames, c, h, w)
        video_features = vision_tower(pixel_values)
        video_features = self._select_image_features(
            video_features,
            strategy=self.config.vision_feature_select_strategy,
        )
        video_features = self.multi_modal_projector(video_features)
        video_features = self.apply_pooling(video_features)
        video_features = video_features.reshape(
            b, frames * video_features.shape[1], -1)
        image_newline = self.image_newline[None, None, :].repeat(
            b, 1, 1).to(video_features.device)
        video_features = torch.cat((video_features, image_newline), dim=1)
        video_features = video_features.flatten(0, 1)
        return video_features

    def _process_video_pixels(self, inputs: LlavaOnevisionVideoPixelInputs):
        assert self.vision_tower is not None

        video_pixels = inputs["data"]

        # TODO: support multiple videos per input
        if isinstance(video_pixels, torch.Tensor):
            stacked_embeddings = self._video_pixels_to_features(
                self.vision_tower, video_pixels)
            return stacked_embeddings
        else:
            raise ValueError(
                f"Unsupported type of video input {type(video_pixels)}")

    def apply_pooling(self, image_features, stride=2):
        vision_config = self.config.vision_config
        height = width = vision_config.image_size // vision_config.patch_size
        batch_frames, _, dim = image_features.shape
        image_features = image_features.view(batch_frames, height, width, -1)
        image_features = image_features.permute(0, 3, 1, 2)

        # TODO: support other pooling types via config
        height, width = image_features.shape[2:]
        scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)]
        image_feature = nn.functional.interpolate(image_features,
                                                  size=scaled_shape,
                                                  mode="bilinear")
        image_feature = image_feature.permute(0, 2, 3, 1)
        image_feature = image_feature.view(batch_frames, -1, dim)
        return image_feature

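    # In forward() below, multimodal embeddings are merged by replacing the
    # input embeddings at placeholder positions (token ids equal to
    # image_token_index / video_token_index) with the projected vision
    # features; `input_ids` is then dropped so the language model consumes
    # `inputs_embeds` directly.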
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run the forward pass for LLaVA-OneVision.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values_videos: Pixels of each frame for each input video.
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)

        # merge image/video embeddings into the input embeddings
        if modalities:
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)
            if "images" in modalities:
                image_input = modalities["images"]
                vision_embeddings = self._process_image_input(image_input)
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, vision_embeddings,
                    self.config.image_token_index)
            if "videos" in modalities:
                video_input = modalities["videos"]
                video_embeddings = self._process_video_pixels(video_input)
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, video_embeddings,
                    self.config.video_token_index)
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # prepare weight iterators for components
        weights_group = group_weights_with_prefix(weights)

        # load vision encoder
        self.vision_tower.load_weights(weights_group["vision_tower"])

        # load mlp projector
        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in weights_group["multi_modal_projector"]:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        self.language_model.load_weights(weights_group["language_model"])
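        # NOTE: the `image_newline` parameter created in __init__ is not
        # covered by the groups loaded above; if the checkpoint ships an
        # `image_newline` weight, it would need to be loaded separately.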