llava_next.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637
  1. import itertools
  2. from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
  3. TypedDict, Union)
  4. import torch
  5. import torch.nn as nn
  6. from PIL import Image
  7. from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
  8. from transformers.models.llava_next.modeling_llava_next import (
  9. get_anyres_image_grid_shape, unpad_image)
  10. from typing_extensions import NotRequired
  11. from aphrodite.attention import AttentionMetadata
  12. from aphrodite.common.config import CacheConfig, MultiModalConfig
  13. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  14. from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
  15. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  16. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  17. from aphrodite.multimodal import MULTIMODAL_REGISTRY
  18. from aphrodite.quantization.base_config import QuantizationConfig
  19. from .clip import (CLIPVisionModel, dummy_image_for_clip,
  20. dummy_seq_data_for_clip, get_clip_image_feature_size,
  21. get_clip_patch_grid_length, input_processor_for_clip)
  22. from .interfaces import SupportsMultiModal
  23. from .llava import LlavaMultiModalProjector
  24. from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
  25. dummy_seq_data_for_siglip, get_siglip_image_feature_size,
  26. get_siglip_patch_grid_length, input_processor_for_siglip)
  27. from .utils import (filter_weights, init_aphrodite_registered_model,
  28. merge_multimodal_embeddings)
  29. _KEYS_TO_MODIFY_MAPPING = {
  30. "language_model.lm_head": "lm_head",
  31. "language_model.model": "language_model",
  32. }
  33. # Result in the max possible feature size (2x2 grid of 336x336px tiles)
  34. MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
  35. class LlavaNextImagePixelInputs(TypedDict):
  36. type: Literal["pixel_values"]
  37. data: Union[torch.Tensor, List[torch.Tensor]]
  38. """
  39. Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
  40. Note that `num_patches` may be different for each batch, in which case
  41. the data is passed as a list instead of a batched tensor.
  42. """
  43. image_sizes: NotRequired[torch.Tensor]
  44. """
  45. Shape: `(batch_size, 2)`
  46. This should be in `(height, width)` format.
  47. """
  48. class LlavaNextImageEmbeddingInputs(TypedDict):
  49. type: Literal["image_embeds"]
  50. data: torch.Tensor
  51. """Shape: `(batch_size, image_feature_size, hidden_size)`
  52. `hidden_size` must match the hidden size of language model backbone.
  53. """
  54. LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
  55. LlavaNextImageEmbeddingInputs]
  56. # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
  57. def _get_llava_next_num_unpadded_features(
  58. original_height: int,
  59. original_width: int,
  60. npatches: int,
  61. num_patch_height: int,
  62. num_patch_width: int,
  63. ) -> Tuple[int, int]:
  64. current_height = npatches * num_patch_height
  65. current_width = npatches * num_patch_width
  66. aspect_ratio = original_width / original_height
  67. current_aspect_ratio = current_width / current_height
  68. if aspect_ratio > current_aspect_ratio:
  69. new_height = (original_height * current_width) // original_width
  70. padding = (current_height - new_height) // 2
  71. current_height -= padding * 2
  72. else:
  73. new_width = (original_width * current_height) // original_height
  74. padding = (current_width - new_width) // 2
  75. current_width -= padding * 2
  76. unpadded_features = current_height * current_width
  77. newline_features = current_height
  78. return (unpadded_features, newline_features)
  79. # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
  80. def get_llava_next_image_feature_size(
  81. hf_config: LlavaNextConfig,
  82. *,
  83. input_height: int,
  84. input_width: int,
  85. ) -> int:
  86. vision_config = hf_config.vision_config
  87. if isinstance(vision_config, CLIPVisionConfig):
  88. num_patches = get_clip_patch_grid_length(
  89. image_size=vision_config.image_size,
  90. patch_size=vision_config.patch_size,
  91. )
  92. base_feature_size = get_clip_image_feature_size(vision_config)
  93. elif isinstance(vision_config, SiglipVisionConfig):
  94. num_patches = get_siglip_patch_grid_length(
  95. image_size=vision_config.image_size,
  96. patch_size=vision_config.patch_size,
  97. )
  98. base_feature_size = get_siglip_image_feature_size(vision_config)
  99. else:
  100. msg = f"Unsupported vision config: {type(vision_config)}"
  101. raise NotImplementedError(msg)
  102. strategy = hf_config.vision_feature_select_strategy
  103. if strategy == "default":
  104. base_feature_size -= 1
  105. elif strategy == "full":
  106. pass
  107. else:
  108. raise ValueError(f"Unexpected select feature strategy: {strategy}")
  109. num_patch_height, num_patch_width = get_anyres_image_grid_shape(
  110. image_size=(input_height, input_width),
  111. grid_pinpoints=hf_config.image_grid_pinpoints,
  112. patch_size=vision_config.image_size,
  113. )
  114. (
  115. unpadded_feature_size,
  116. newline_feature_size,
  117. ) = _get_llava_next_num_unpadded_features(input_height, input_width,
  118. num_patches, num_patch_height,
  119. num_patch_width)
  120. return unpadded_feature_size + newline_feature_size + base_feature_size
  121. def get_max_llava_next_image_tokens(ctx: InputContext):
  122. return get_llava_next_image_feature_size(
  123. ctx.get_hf_config(LlavaNextConfig),
  124. input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
  125. input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
  126. )
  127. def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
  128. mm_counts: Mapping[str, int]):
  129. hf_config = ctx.get_hf_config(LlavaNextConfig)
  130. vision_config = hf_config.vision_config
  131. num_images = mm_counts["image"]
  132. image_feature_size = get_max_llava_next_image_tokens(ctx)
  133. if isinstance(vision_config, CLIPVisionConfig):
  134. seq_data = dummy_seq_data_for_clip(
  135. vision_config,
  136. seq_len,
  137. num_images,
  138. image_token_id=hf_config.image_token_index,
  139. image_feature_size_override=image_feature_size,
  140. )
  141. mm_data = dummy_image_for_clip(
  142. vision_config,
  143. num_images,
  144. image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
  145. image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
  146. )
  147. return seq_data, mm_data
  148. elif isinstance(vision_config, SiglipVisionConfig):
  149. seq_data = dummy_seq_data_for_siglip(
  150. vision_config,
  151. seq_len,
  152. num_images,
  153. image_token_id=hf_config.image_token_index,
  154. image_feature_size_override=image_feature_size,
  155. )
  156. mm_data = dummy_image_for_siglip(
  157. vision_config,
  158. num_images,
  159. image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
  160. image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
  161. )
  162. return seq_data, mm_data
  163. msg = f"Unsupported vision config: {type(vision_config)}"
  164. raise NotImplementedError(msg)
  165. def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
  166. multi_modal_data = llm_inputs.get("multi_modal_data")
  167. if multi_modal_data is None or "image" not in multi_modal_data:
  168. return llm_inputs
  169. model_config = ctx.model_config
  170. hf_config = ctx.get_hf_config(LlavaNextConfig)
  171. vision_config = hf_config.vision_config
  172. image_data = multi_modal_data["image"]
  173. if isinstance(image_data, Image.Image):
  174. width, height = image_data.size
  175. image_feature_size = get_llava_next_image_feature_size(
  176. hf_config,
  177. input_height=height,
  178. input_width=width,
  179. )
  180. elif isinstance(image_data, torch.Tensor):
  181. image_feature_size = image_data.shape[0]
  182. else:
  183. raise TypeError(f"Invalid image type: {type(image_data)}")
  184. vision_config = hf_config.vision_config
  185. if isinstance(vision_config, CLIPVisionConfig):
  186. return input_processor_for_clip(
  187. model_config,
  188. vision_config,
  189. llm_inputs,
  190. image_token_id=hf_config.image_token_index,
  191. image_feature_size_override=image_feature_size,
  192. )
  193. elif isinstance(vision_config, SiglipVisionConfig):
  194. return input_processor_for_siglip(
  195. model_config,
  196. vision_config,
  197. llm_inputs,
  198. image_token_id=hf_config.image_token_index,
  199. image_feature_size_override=image_feature_size,
  200. )
  201. msg = f"Unsupported vision config: {type(vision_config)}"
  202. raise NotImplementedError(msg)
  203. def _init_vision_tower(hf_config: LlavaNextConfig):
  204. vision_config = hf_config.vision_config
  205. # Initialize the vision tower only up to the required feature layer
  206. vision_feature_layer = hf_config.vision_feature_layer
  207. if vision_feature_layer < 0:
  208. num_hidden_layers = hf_config.vision_config.num_hidden_layers \
  209. + vision_feature_layer + 1
  210. else:
  211. num_hidden_layers = vision_feature_layer + 1
  212. if isinstance(vision_config, CLIPVisionConfig):
  213. return CLIPVisionModel(
  214. vision_config,
  215. num_hidden_layers_override=num_hidden_layers,
  216. )
  217. elif isinstance(vision_config, SiglipVisionConfig):
  218. return SiglipVisionModel(
  219. vision_config,
  220. num_hidden_layers_override=num_hidden_layers,
  221. )
  222. msg = f"Unsupported vision config: {type(vision_config)}"
  223. raise NotImplementedError(msg)
  224. @MULTIMODAL_REGISTRY.register_image_input_mapper()
  225. @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
  226. @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
  227. @INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
  228. class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
  229. def __init__(self,
  230. config: LlavaNextConfig,
  231. multimodal_config: MultiModalConfig,
  232. cache_config: Optional[CacheConfig] = None,
  233. quant_config: Optional[QuantizationConfig] = None) -> None:
  234. super().__init__()
  235. self.config = config
  236. self.multimodal_config = multimodal_config
  237. # TODO: Optionally initializes this for supporting embeddings.
  238. self.vision_tower = _init_vision_tower(config)
  239. self.multi_modal_projector = LlavaMultiModalProjector(
  240. vision_hidden_size=config.vision_config.hidden_size,
  241. text_hidden_size=config.text_config.hidden_size,
  242. projector_hidden_act=config.projector_hidden_act)
  243. self.language_model = init_aphrodite_registered_model(
  244. config.text_config, cache_config, quant_config)
  245. self.image_newline = nn.Parameter(
  246. torch.empty(config.text_config.hidden_size))
  247. def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
  248. if list(data.shape[1:]) != [2]:
  249. raise ValueError(
  250. f"The expected image sizes shape is batch dimension plus "
  251. f"{[2]}. You supplied {data.shape}.")
  252. return data
  253. def _validate_pixel_values(
  254. self, data: Union[torch.Tensor, List[torch.Tensor]]
  255. ) -> Union[torch.Tensor, List[torch.Tensor]]:
  256. h = w = self.config.vision_config.image_size
  257. expected_dims = (3, h, w)
  258. def _validate_shape(d: torch.Tensor):
  259. actual_dims = tuple(d.shape[1:])
  260. if actual_dims != expected_dims:
  261. expected_expr = ("num_patches", *map(str, expected_dims))
  262. raise ValueError(
  263. "The expected shape of pixel values in each batch element "
  264. f"is {expected_expr}. You supplied {tuple(d.shape)}.")
  265. for d in data:
  266. _validate_shape(d)
  267. return data
  268. def _parse_and_validate_image_input(
  269. self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
  270. pixel_values = kwargs.pop("pixel_values", None)
  271. image_sizes = kwargs.pop("image_sizes", None)
  272. image_embeds = kwargs.pop("image_embeds", None)
  273. if pixel_values is None and image_embeds is None:
  274. return None
  275. if pixel_values is not None:
  276. if not isinstance(pixel_values, (torch.Tensor, list)):
  277. raise ValueError("Incorrect type of pixel values. "
  278. f"Got type: {type(pixel_values)}")
  279. if not isinstance(image_sizes, torch.Tensor):
  280. raise ValueError("Incorrect type of image sizes. "
  281. f"Got type: {type(image_sizes)}")
  282. return LlavaNextImagePixelInputs(
  283. type="pixel_values",
  284. data=self._validate_pixel_values(pixel_values),
  285. image_sizes=self._validate_image_sizes(image_sizes),
  286. )
  287. if image_embeds is not None:
  288. if not isinstance(image_embeds, torch.Tensor):
  289. raise ValueError("Incorrect type of image embeds. "
  290. f"Got type: {type(image_embeds)}")
  291. return LlavaNextImageEmbeddingInputs(
  292. type="image_embeds",
  293. data=image_embeds,
  294. )
  295. raise AssertionError("This line should be unreachable.")
  296. def _select_image_features(self, image_features: torch.Tensor, *,
  297. strategy: str) -> torch.Tensor:
  298. # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
  299. if strategy == "default":
  300. return image_features[:, 1:]
  301. elif strategy == "full":
  302. return image_features
  303. raise ValueError(f"Unexpected select feature strategy: {strategy}")
  304. def _image_pixels_to_features(
  305. self,
  306. vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
  307. pixel_values: torch.Tensor,
  308. ) -> torch.Tensor:
  309. # NOTE: we skip the step to select the vision feature layer since
  310. # this is already done inside the vision tower
  311. image_features = vision_tower(pixel_values)
  312. return self._select_image_features(
  313. image_features,
  314. strategy=self.config.vision_feature_select_strategy,
  315. )
  316. # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
  317. def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
  318. patch_embeddings: torch.Tensor, *,
  319. strategy: str) -> torch.Tensor:
  320. if strategy == "flat":
  321. return patch_embeddings.flatten(0, 1)
  322. if strategy.startswith("spatial"):
  323. height = width = self.config.vision_config.image_size \
  324. // self.config.vision_config.patch_size
  325. base_patch_embeds = patch_embeddings[0]
  326. if height * width != base_patch_embeds.shape[0]:
  327. raise ValueError(
  328. "The number of patches is not consistent with the "
  329. "image size.")
  330. if patch_embeddings.shape[0] > 1:
  331. other_patch_embeds = patch_embeddings[1:]
  332. # Move to CPU to avoid floating-point errors
  333. orig_height, orig_width = image_size.tolist()
  334. # image_aspect_ratio == "anyres"
  335. num_patch_height, num_patch_width = get_anyres_image_grid_shape(
  336. (orig_height, orig_width),
  337. self.config.image_grid_pinpoints,
  338. self.config.vision_config.image_size,
  339. )
  340. other_patch_embeds = other_patch_embeds \
  341. .view(num_patch_height, num_patch_width, height, width, -1)
  342. if "unpad" in strategy:
  343. other_patch_embeds = other_patch_embeds \
  344. .permute(4, 0, 2, 1, 3).contiguous() \
  345. .flatten(1, 2).flatten(2, 3)
  346. other_patch_embeds = unpad_image(other_patch_embeds,
  347. (orig_height, orig_width))
  348. other_patch_embeds = torch.cat((
  349. other_patch_embeds,
  350. self.image_newline[:, None, None] \
  351. .expand(*other_patch_embeds.shape[:-1], 1) \
  352. .to(other_patch_embeds.device),
  353. ), dim=-1)
  354. other_patch_embeds = other_patch_embeds \
  355. .flatten(1, 2).transpose(0, 1)
  356. else:
  357. other_patch_embeds = other_patch_embeds \
  358. .permute(0, 2, 1, 3, 4).contiguous() \
  359. .flatten(0, 3)
  360. merged_patch_embeddings = torch.cat(
  361. (base_patch_embeds, other_patch_embeds), dim=0)
  362. else:
  363. if "unpad" in strategy:
  364. merged_patch_embeddings = torch.cat(
  365. (base_patch_embeds,
  366. self.image_newline[None] \
  367. .to(base_patch_embeds.device)
  368. ), dim=0)
  369. else:
  370. merged_patch_embeddings = base_patch_embeds
  371. return merged_patch_embeddings
  372. raise ValueError(f"Unexpected patch merge strategy: {strategy}")
  373. def _process_image_pixels(
  374. self,
  375. inputs: LlavaNextImagePixelInputs,
  376. ) -> Union[torch.Tensor, List[torch.Tensor]]:
  377. assert self.vision_tower is not None
  378. pixel_values = inputs["data"]
  379. if isinstance(pixel_values, torch.Tensor):
  380. b, num_patches, c, h, w = pixel_values.shape
  381. stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
  382. stacked_image_features = self._image_pixels_to_features(
  383. self.vision_tower, stacked_pixel_values)
  384. stacked_patch_embeddings = self.multi_modal_projector(
  385. stacked_image_features)
  386. return stacked_patch_embeddings.view(
  387. b, num_patches, *stacked_patch_embeddings.shape[1:])
  388. num_patches_per_batch = [v.shape[0] for v in pixel_values]
  389. stacked_pixel_values = torch.cat(pixel_values)
  390. stacked_image_features = self._image_pixels_to_features(
  391. self.vision_tower, stacked_pixel_values)
  392. return [
  393. self.multi_modal_projector(image_features) for image_features in
  394. torch.split(stacked_image_features, num_patches_per_batch)
  395. ]
  396. def _process_image_input(
  397. self,
  398. image_input: LlavaNextImageInputs,
  399. ) -> Union[torch.Tensor, List[torch.Tensor]]:
  400. if image_input["type"] == "image_embeds":
  401. return [image_input["data"]]
  402. patch_embeddings = self._process_image_pixels(image_input)
  403. image_sizes = image_input.get("image_sizes")
  404. if image_sizes is None:
  405. batch_size = len(image_input["data"])
  406. vision_config = self.config.vision_config
  407. default_height = default_width = vision_config.image_size
  408. image_sizes = torch.as_tensor([[default_height, default_width]
  409. for _ in range(batch_size)])
  410. return [
  411. self._merge_image_patch_embeddings(image_sizes[i],
  412. patch_features_batch,
  413. strategy="spatial_unpad")
  414. for i, patch_features_batch in enumerate(patch_embeddings)
  415. ]
  416. def forward(
  417. self,
  418. input_ids: torch.Tensor,
  419. positions: torch.Tensor,
  420. kv_caches: List[torch.Tensor],
  421. attn_metadata: AttentionMetadata,
  422. intermediate_tensors: Optional[IntermediateTensors] = None,
  423. **kwargs: object,
  424. ) -> SamplerOutput:
  425. """Run forward pass for LlaVA-NeXT.
  426. One key thing to understand is the `input_ids` already accounts for the
  427. positions of the to-be-inserted image embeddings.
  428. Concretely, consider a text prompt:
  429. `"A chat between a curious human and an artificial intelligence
  430. assistant. The assistant gives helpful, detailed, and polite answers to
  431. the human's questions.
  432. USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.
  433. Tokenizer outputs:
  434. `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
  435. 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
  436. 6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
  437. 29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
  438. 9047, 13566, 29901]`.
  439. To reserve space in KV cache, we have to insert placeholder tokens
  440. before they are inputted to the model, so the input processor prepends
  441. additional image tokens (denoted as `32000`), resulting in:
  442. `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
  443. 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
  444. 6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
  445. 29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
  446. 319, 1799, 9047, 13566, 29901]`.
  447. Unlike in LLaVA-1.5, the number of image tokens inputted to the language
  448. model depends on the original size of the input image. Including the
  449. original image token in the input, the required number of image tokens
  450. is given by :func:`get_llava_next_image_feature_size`.
  451. This way, the `positions` and `attn_metadata` are consistent
  452. with the `input_ids`.
  453. Args:
  454. input_ids: Flattened (concatenated) input_ids corresponding to a
  455. batch.
  456. pixel_values: The pixels in each grid patch for each input image.
  457. image_sizes: The original `(height, width)` for each input image.
  458. See also:
  459. :class:`LlavaNextImageInputs`
  460. """
  461. image_input = self._parse_and_validate_image_input(**kwargs)
  462. if image_input is not None:
  463. vision_embeddings = self._process_image_input(image_input)
  464. inputs_embeds = self.language_model.model.get_input_embeddings(
  465. input_ids)
  466. inputs_embeds = merge_multimodal_embeddings(
  467. input_ids, inputs_embeds, vision_embeddings,
  468. self.config.image_token_index)
  469. input_ids = None
  470. else:
  471. inputs_embeds = None
  472. hidden_states = self.language_model.model(input_ids,
  473. positions,
  474. kv_caches,
  475. attn_metadata,
  476. None,
  477. inputs_embeds=inputs_embeds)
  478. return hidden_states
  479. def compute_logits(
  480. self,
  481. hidden_states: torch.Tensor,
  482. sampling_metadata: SamplingMetadata,
  483. ) -> Optional[torch.Tensor]:
  484. return self.language_model.compute_logits(hidden_states,
  485. sampling_metadata)
  486. def sample(
  487. self,
  488. logits: torch.Tensor,
  489. sampling_metadata: SamplingMetadata,
  490. ) -> Optional[SamplerOutput]:
  491. return self.language_model.sample(logits, sampling_metadata)
  492. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  493. # prepare weight iterators for components
  494. vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
  495. weights, 4)
  496. # load vision encoder
  497. vit_weights = filter_weights(vit_weights, "vision_tower")
  498. self.vision_tower.load_weights(vit_weights)
  499. # load mlp projector
  500. mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
  501. mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
  502. for name, loaded_weight in mlp_weights:
  503. param = mlp_params_dict[name]
  504. weight_loader = getattr(param, "weight_loader",
  505. default_weight_loader)
  506. weight_loader(param, loaded_weight)
  507. # load newline
  508. newline_weights = filter_weights(newline_weights, "image_newline")
  509. for name, loaded_weight in newline_weights:
  510. assert name == ""
  511. param = self.image_newline
  512. weight_loader = getattr(param, "weight_loader",
  513. default_weight_loader)
  514. weight_loader(param, loaded_weight)
  515. # load llm backbone
  516. llm_weights = filter_weights(llm_weights, "language_model")
  517. self.language_model.load_weights(llm_weights)