1
0

llava_next.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601
  1. import itertools
  2. from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
  3. import torch
  4. import torch.nn as nn
  5. from PIL import Image
  6. from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
  7. from transformers.models.llava_next.modeling_llava_next import (
  8. get_anyres_image_grid_shape, unpad_image)
  9. from typing_extensions import NotRequired
  10. from aphrodite.attention import AttentionMetadata
  11. from aphrodite.common.config import CacheConfig, MultiModalConfig
  12. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  13. from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
  14. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  15. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  16. from aphrodite.multimodal import MULTIMODAL_REGISTRY
  17. from aphrodite.quantization.base_config import QuantizationConfig
  18. from .clip import (CLIPVisionModel, dummy_image_for_clip,
  19. dummy_seq_data_for_clip, get_clip_image_feature_size,
  20. get_clip_patch_grid_length, input_processor_for_clip)
  21. from .interfaces import SupportsVision
  22. from .llava import LlavaMultiModalProjector
  23. from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
  24. dummy_seq_data_for_siglip, get_siglip_image_feature_size,
  25. get_siglip_patch_grid_length, input_processor_for_siglip)
  26. from .utils import (filter_weights, init_aphrodite_registered_model,
  27. merge_vision_embeddings)
# Maps one dotted weight-name prefix to another.
# NOTE(review): presumably used to rename checkpoint keys, but nothing in the
# visible part of this file reads this mapping -- confirm it is still needed.
_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}

# Image dimensions used by `get_max_llava_next_image_tokens` and
# `dummy_data_for_llava_next`. They are chosen to result in the max possible
# feature size (2x2 grid of 336x336px tiles).
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
class LlavaNextImagePixelInputs(TypedDict):
    """Image input for LLaVA-NeXT supplied as pixel values."""

    # Tag identifying this input variant; always the literal "pixel_values".
    type: Literal["pixel_values"]

    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.
    """

    image_sizes: NotRequired[torch.Tensor]
    """
    Shape: `(batch_size, 2)`

    This should be in `(height, width)` format.
    """


# Pixel values are the only image input variant declared in this module.
LlavaNextImageInputs = LlavaNextImagePixelInputs
  48. # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
  49. def _get_llava_next_num_unpadded_features(
  50. original_height: int,
  51. original_width: int,
  52. npatches: int,
  53. num_patch_height: int,
  54. num_patch_width: int,
  55. ) -> Tuple[int, int]:
  56. current_height = npatches * num_patch_height
  57. current_width = npatches * num_patch_width
  58. aspect_ratio = original_width / original_height
  59. current_aspect_ratio = current_width / current_height
  60. if aspect_ratio > current_aspect_ratio:
  61. new_height = (original_height * current_width) // original_width
  62. padding = (current_height - new_height) // 2
  63. current_height -= padding * 2
  64. else:
  65. new_width = (original_width * current_height) // original_height
  66. padding = (current_width - new_width) // 2
  67. current_width -= padding * 2
  68. unpadded_features = current_height * current_width
  69. newline_features = current_height
  70. return (unpadded_features, newline_features)
  71. # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
  72. def get_llava_next_image_feature_size(
  73. hf_config: LlavaNextConfig,
  74. *,
  75. input_height: int,
  76. input_width: int,
  77. ) -> int:
  78. vision_config = hf_config.vision_config
  79. if isinstance(vision_config, CLIPVisionConfig):
  80. num_patches = get_clip_patch_grid_length(
  81. image_size=vision_config.image_size,
  82. patch_size=vision_config.patch_size,
  83. )
  84. base_feature_size = get_clip_image_feature_size(vision_config)
  85. elif isinstance(vision_config, SiglipVisionConfig):
  86. num_patches = get_siglip_patch_grid_length(
  87. image_size=vision_config.image_size,
  88. patch_size=vision_config.patch_size,
  89. )
  90. base_feature_size = get_siglip_image_feature_size(vision_config)
  91. else:
  92. msg = f"Unsupported vision config: {type(vision_config)}"
  93. raise NotImplementedError(msg)
  94. strategy = hf_config.vision_feature_select_strategy
  95. if strategy == "default":
  96. base_feature_size -= 1
  97. elif strategy == "full":
  98. pass
  99. else:
  100. raise ValueError(f"Unexpected select feature strategy: {strategy}")
  101. num_patch_height, num_patch_width = get_anyres_image_grid_shape(
  102. image_size=(input_height, input_width),
  103. grid_pinpoints=hf_config.image_grid_pinpoints,
  104. patch_size=vision_config.image_size,
  105. )
  106. (
  107. unpadded_feature_size,
  108. newline_feature_size,
  109. ) = _get_llava_next_num_unpadded_features(input_height, input_width,
  110. num_patches, num_patch_height,
  111. num_patch_width)
  112. return unpadded_feature_size + newline_feature_size + base_feature_size
  113. def get_max_llava_next_image_tokens(ctx: InputContext):
  114. return get_llava_next_image_feature_size(
  115. ctx.get_hf_config(LlavaNextConfig),
  116. input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
  117. input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
  118. )
  119. def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
  120. hf_config = ctx.get_hf_config(LlavaNextConfig)
  121. vision_config = hf_config.vision_config
  122. image_feature_size = get_max_llava_next_image_tokens(ctx)
  123. if isinstance(vision_config, CLIPVisionConfig):
  124. seq_data = dummy_seq_data_for_clip(
  125. vision_config,
  126. seq_len,
  127. image_token_id=hf_config.image_token_index,
  128. image_feature_size_override=image_feature_size,
  129. )
  130. mm_data = dummy_image_for_clip(
  131. vision_config,
  132. image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
  133. image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
  134. )
  135. return seq_data, mm_data
  136. elif isinstance(vision_config, SiglipVisionConfig):
  137. seq_data = dummy_seq_data_for_siglip(
  138. vision_config,
  139. seq_len,
  140. image_token_id=hf_config.image_token_index,
  141. image_feature_size_override=image_feature_size,
  142. )
  143. mm_data = dummy_image_for_siglip(
  144. vision_config,
  145. image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
  146. image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
  147. )
  148. return seq_data, mm_data
  149. msg = f"Unsupported vision config: {type(vision_config)}"
  150. raise NotImplementedError(msg)
  151. def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
  152. multi_modal_data = llm_inputs.get("multi_modal_data")
  153. if multi_modal_data is None or "image" not in multi_modal_data:
  154. return llm_inputs
  155. model_config = ctx.model_config
  156. hf_config = ctx.get_hf_config(LlavaNextConfig)
  157. vision_config = hf_config.vision_config
  158. image_data = multi_modal_data["image"]
  159. if isinstance(image_data, Image.Image):
  160. width, height = image_data.size
  161. image_feature_size = get_llava_next_image_feature_size(
  162. hf_config,
  163. input_height=height,
  164. input_width=width,
  165. )
  166. elif isinstance(image_data, torch.Tensor):
  167. raise NotImplementedError("Embeddings input is not supported yet")
  168. else:
  169. raise TypeError(f"Invalid image type: {type(image_data)}")
  170. vision_config = hf_config.vision_config
  171. if isinstance(vision_config, CLIPVisionConfig):
  172. return input_processor_for_clip(
  173. model_config,
  174. vision_config,
  175. llm_inputs,
  176. image_token_id=hf_config.image_token_index,
  177. image_feature_size_override=image_feature_size,
  178. )
  179. elif isinstance(vision_config, SiglipVisionConfig):
  180. return input_processor_for_siglip(
  181. model_config,
  182. vision_config,
  183. llm_inputs,
  184. image_token_id=hf_config.image_token_index,
  185. image_feature_size_override=image_feature_size,
  186. )
  187. msg = f"Unsupported vision config: {type(vision_config)}"
  188. raise NotImplementedError(msg)
  189. def _init_vision_tower(hf_config: LlavaNextConfig):
  190. vision_config = hf_config.vision_config
  191. # Initialize the vision tower only up to the required feature layer
  192. vision_feature_layer = hf_config.vision_feature_layer
  193. if vision_feature_layer < 0:
  194. num_hidden_layers = hf_config.vision_config.num_hidden_layers \
  195. + vision_feature_layer + 1
  196. else:
  197. num_hidden_layers = vision_feature_layer + 1
  198. if isinstance(vision_config, CLIPVisionConfig):
  199. return CLIPVisionModel(
  200. vision_config,
  201. num_hidden_layers_override=num_hidden_layers,
  202. )
  203. elif isinstance(vision_config, SiglipVisionConfig):
  204. return SiglipVisionModel(
  205. vision_config,
  206. num_hidden_layers_override=num_hidden_layers,
  207. )
  208. msg = f"Unsupported vision config: {type(vision_config)}"
  209. raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
    """LLaVA-NeXT: a vision tower, a multimodal projector and a language model.

    Image pixel values are encoded by the vision tower, projected by the
    multimodal projector, merged per image (with `image_newline` entries
    appended) and then merged into the language model's input embeddings at
    the positions of the image token.
    """

    def __init__(self,
                 config: LlavaNextConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        """Build the sub-modules from `config`.

        Args:
            config: LLaVA-NeXT config holding the vision and text configs.
            multimodal_config: Stored on the instance; not otherwise used in
                this constructor.
            cache_config: Forwarded to the language model initializer.
            quant_config: Forwarded to the language model initializer.
        """
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = _init_vision_tower(config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_aphrodite_registered_model(
            config.text_config, cache_config, quant_config)

        # Created uninitialized; `load_weights` fills it from the
        # "image_newline" checkpoint entry.
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every dimension after the first has shape `[2]`.

        Returns:
            `data` unchanged.

        Raises:
            ValueError: If the trailing shape is not `[2]`.
        """
        if list(data.shape[1:]) != [2]:
            raise ValueError(
                f"The expected image sizes shape is batch dimension plus "
                f"{[2]}. You supplied {data.shape}.")

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Check each batch element's trailing shape against the tower size.

        Every element obtained by iterating `data` must have shape
        `(num_patches, 3, image_size, image_size)`, where `image_size` comes
        from the vision config; the leading dimension is not checked.

        Returns:
            `data` unchanged.

        Raises:
            ValueError: If any element has a different trailing shape.
        """

        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            # Skip the leading (num_patches) dimension.
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values in each batch element "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
        """Extract and validate `pixel_values` / `image_sizes` from kwargs.

        Returns:
            `None` when `pixel_values` is absent, otherwise the validated
            pixel inputs.

        Raises:
            ValueError: If `pixel_values` is neither a tensor nor a list, if
                `image_sizes` is not a tensor, or if shape validation fails.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        # NOTE(review): `image_sizes` is mandatory here whenever pixel values
        # are given, although the TypedDict marks it NotRequired and
        # `_process_image_input` has a fallback for None -- confirm intent.
        if not isinstance(image_sizes, torch.Tensor):
            raise ValueError("Incorrect type of image sizes. "
                             f"Got type: {type(image_sizes)}")

        return LlavaNextImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
            image_sizes=self._validate_image_sizes(image_sizes),
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Select features according to `strategy`.

        "default" drops index 0 along dimension 1; "full" keeps everything.

        Raises:
            ValueError: For any other strategy.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            # Presumably index 0 is the class token -- confirm per tower.
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower and apply the configured select strategy."""

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        """Merge the per-patch embeddings of one image into one sequence.

        Args:
            image_size: Two-element tensor unpacked as
                `(orig_height, orig_width)`.
            patch_embeddings: Embeddings of one image, indexed by patch on
                dimension 0; index 0 is used as the base patch.
            strategy: "flat", or a string starting with "spatial"; when it
                also contains "unpad", padding is removed and
                `image_newline` entries are appended.

        Returns:
            A 2-D tensor with the merged embeddings on dimension 0.

        Raises:
            ValueError: If the base patch does not have
                `(image_size // patch_size) ** 2` entries, or if the
                strategy is not recognized.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            # Side length of one tile's patch grid.
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # Move to CPU to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

                # image_aspect_ratio == "anyres"
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                # -> (num_patch_height, num_patch_width, height, width, C)
                other_patch_embeds = other_patch_embeds \
                    .view(num_patch_height, num_patch_width, height, width, -1)

                if "unpad" in strategy:
                    # -> (C, num_patch_height * height,
                    #        num_patch_width * width)
                    other_patch_embeds = other_patch_embeds \
                        .permute(4, 0, 2, 1, 3).contiguous() \
                        .flatten(1, 2).flatten(2, 3)
                    other_patch_embeds = unpad_image(other_patch_embeds,
                                                     (orig_height, orig_width))
                    # Append one `image_newline` entry at the end of every
                    # row (concatenated along the last dimension).
                    other_patch_embeds = torch.cat((
                        other_patch_embeds,
                        self.image_newline[:, None, None] \
                            .expand(*other_patch_embeds.shape[:-1], 1) \
                            .to(other_patch_embeds.device),
                    ), dim=-1)
                    # Flatten the two spatial dimensions -> (tokens, C)
                    other_patch_embeds = other_patch_embeds \
                        .flatten(1, 2).transpose(0, 1)
                else:
                    # -> (num_patch_height, height, num_patch_width, width, C)
                    # and then flattened to (tokens, C)
                    other_patch_embeds = other_patch_embeds \
                        .permute(0, 2, 1, 3, 4).contiguous() \
                        .flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                if "unpad" in strategy:
                    # Only the base patch: append a single newline entry.
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None] \
                            .to(base_patch_embeds.device)
                    ), dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Encode and project pixel values into patch embeddings.

        A batched tensor is returned reshaped back to
        `(b, num_patches, ...)`; a list input yields one projected tensor per
        list element.
        """
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            # Fold the patch dimension into the batch dimension.
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        # List input: patch counts differ per element, so concatenate along
        # dimension 0 for the tower and split back afterwards.
        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return [
            self.multi_modal_projector(image_features) for image_features in
            torch.split(stacked_image_features, num_patches_per_batch)
        ]

    def _process_image_input(
        self,
        image_input: LlavaNextImageInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Turn image inputs into one merged embedding tensor per image.

        Returns:
            A list with one "spatial_unpad"-merged tensor per batch element.
        """
        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            # Fall back to the vision tower's square input size for every
            # batch element.
            batch_size = len(image_input["data"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
                                           for _ in range(batch_size)])

        return [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features_batch,
                                               strategy="spatial_unpad")
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for LlaVA-NeXT.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers to
        the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

        Tokenizer outputs:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
        9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
        319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
        model depends on the original size of the input image. Including the
        original image token in the input, the required number of image tokens
        is given by :func:`get_llava_next_image_feature_size`.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Forwarded unchanged to the language model.
            kv_caches: Forwarded unchanged to the language model.
            attn_metadata: Forwarded unchanged to the language model.
            intermediate_tensors: Accepted but not used in this method.
            pixel_values: The pixels in each grid patch for each input image.
            image_sizes: The original `(height, width)` for each input image.

        Returns:
            The output of the inner language model, named `hidden_states`
            in the body.

        See also:
            :class:`LlavaNextImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            # The language model receives embeddings instead of token ids.
            input_ids = None
        else:
            inputs_embeds = None

        # NOTE(review): the fifth positional argument is passed as None
        # rather than `intermediate_tensors`, and the return annotation says
        # SamplerOutput while the value is named hidden_states -- confirm
        # both against the language model's signature.
        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights into the four components of the model.

        Args:
            weights: Iterable of `(name, tensor)` pairs. It is duplicated into
                four iterators, each filtered by one prefix: "vision_tower",
                "multi_modal_projector", "image_newline", "language_model".
        """
        # prepare weight iterators for components
        # NOTE(review): the four iterators are consumed one after another;
        # `itertools.tee` buffers items until every copy has consumed them,
        # so this may hold many weights at once -- confirm this is acceptable.
        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
            weights, 4)

        # load vision encoder
        vit_weights = filter_weights(vit_weights, "vision_tower")
        self.vision_tower.load_weights(vit_weights)

        # load mlp projector
        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in mlp_weights:
            param = mlp_params_dict[name]
            # Use a parameter-specific loader when one is attached.
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load newline
        newline_weights = filter_weights(newline_weights, "image_newline")
        for name, loaded_weight in newline_weights:
            # Presumably `filter_weights` strips the prefix, leaving an empty
            # name for this top-level parameter -- confirm in `.utils`.
            assert name == ""
            param = self.image_newline
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)