from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
import torch.nn as nn
from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
                          apply_chunking_to_forward)

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors, SequenceData
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.quantization import QuantizationConfig

from .blip import (BlipVisionModel, dummy_image_for_blip,
                   get_max_blip_image_tokens)
from .interfaces import SupportsMultiModal
from .utils import (group_weights_with_prefix, init_aphrodite_registered_model,
                    merge_multimodal_embeddings)

# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN = "<image>"
BLIP2_IMAGE_TOKEN_ID = 50265


class Blip2ImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


class Blip2ImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """


Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs]
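
# For reference, a conforming pixel-input payload would look roughly like the
# following (a sketch; shapes are illustrative and the image size actually
# comes from the model's vision config rather than being fixed at 224):
#
#   Blip2ImagePixelInputs(
#       type="pixel_values",
#       data=torch.rand(batch_size * num_images, 3, 224, 224),
#   )
#
# The embedding variant instead carries pre-computed features of shape
# `(batch_size * num_images, image_feature_size, hidden_size)`.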


class Blip2QFormerMultiHeadAttention(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        is_cross_attention: bool = False,
    ) -> None:
        super().__init__()

        self.config = config

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of "
                f"the number of attention heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = (config.hidden_size //
                                    config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.scaling = self.attention_head_size**-0.5

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            kv_hidden_size = config.encoder_hidden_size
        else:
            kv_hidden_size = config.hidden_size
        self.key = nn.Linear(kv_hidden_size, self.all_head_size)
        self.value = nn.Linear(kv_hidden_size, self.all_head_size)

        self.position_embedding_type = getattr(config,
                                               "position_embedding_type",
                                               "absolute")
        if self.position_embedding_type != "absolute":
            raise NotImplementedError("Unsupported position_embedding_type: "
                                      f"{self.position_embedding_type}")

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        x = x.view(*x.size()[:-1], self.num_attention_heads,
                   self.attention_head_size)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
    ):
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(
                self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(
                self.value(encoder_hidden_states))
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer)

        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))
        attention_probs = torch.softmax(attention_scores * self.scaling,
                                        dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(*context_layer.size()[:-2],
                                           self.all_head_size)

        return context_layer
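
# Shape bookkeeping for the attention above (a sketch, with B = batch size,
# S_q/S_kv = query/key sequence lengths, H = num_attention_heads and
# D = attention_head_size, so H * D = all_head_size):
#
#   hidden_states:            (B, S_q, hidden_size)
#   query/key/value layers:   (B, H, S, D)       after transpose_for_scores
#   attention_scores/probs:   (B, H, S_q, S_kv)
#   context_layer (returned): (B, S_q, all_head_size)
#
# For cross-attention, S_kv is the number of vision tokens and the key/value
# projections read from `encoder_hidden_size` instead of `hidden_size`.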


class Blip2QFormerSelfOutput(nn.Module):

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class Blip2QFormerAttention(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        is_cross_attention: bool = False,
    ) -> None:
        super().__init__()

        self.attention = Blip2QFormerMultiHeadAttention(
            config,
            quant_config=quant_config,
            cache_config=cache_config,
            is_cross_attention=is_cross_attention,
        )

        self.output = Blip2QFormerSelfOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
    ) -> Tuple[torch.Tensor]:
        self_output = self.attention(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
        )
        attention_output = self.output(self_output, hidden_states)

        return attention_output


class Blip2QFormerIntermediate(nn.Module):

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = get_act_fn(config.hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class Blip2QFormerOutput(nn.Module):

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()

        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class Blip2QFormerLayer(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        layer_idx: int,
    ) -> None:
        super().__init__()

        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = Blip2QFormerAttention(config,
                                               quant_config=quant_config,
                                               cache_config=cache_config)

        self.layer_idx = layer_idx

        if layer_idx % config.cross_attention_frequency == 0:
            self.crossattention = Blip2QFormerAttention(
                config,
                quant_config=quant_config,
                cache_config=cache_config,
                is_cross_attention=True)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False

        self.intermediate_query = Blip2QFormerIntermediate(config)
        self.output_query = Blip2QFormerOutput(config)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        query_length: int,
    ):
        attention_output = self.attention(hidden_states)

        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                query_attention_output = self.crossattention(
                    query_attention_output,
                    encoder_hidden_states=encoder_hidden_states,
                )

            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )

            if attention_output.shape[1] > query_length:
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text],
                                         dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )

        return layer_output

    def feed_forward_chunk(self,
                           attention_output: torch.Tensor) -> torch.Tensor:
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(
            self, attention_output: torch.Tensor) -> torch.Tensor:
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output


class Blip2QFormerEncoder(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
    ) -> None:
        super().__init__()

        self.config = config

        self.layer = nn.ModuleList([
            Blip2QFormerLayer(config,
                              quant_config=quant_config,
                              cache_config=cache_config,
                              layer_idx=layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        query_length: int,
    ) -> torch.Tensor:
        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]

            hidden_states = layer_module(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                query_length=query_length,
            )

        return hidden_states


# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1025
class Blip2QFormerModel(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
    ) -> None:
        super().__init__()

        self.config = config

        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.encoder = Blip2QFormerEncoder(config,
                                           quant_config=quant_config,
                                           cache_config=cache_config)

    def forward(
        self,
        query_embeds: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
    ) -> torch.Tensor:
        query_length = query_embeds.shape[1]

        embedding_output = self.layernorm(query_embeds)
        embedding_output = self.dropout(embedding_output)

        sequence_output = self.encoder(
            embedding_output,
            encoder_hidden_states=encoder_hidden_states,
            query_length=query_length,
        )

        return sequence_output
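
# A minimal sketch of how this Q-Former is driven (see
# `Blip2ForConditionalGeneration._process_image_input` below): the learned
# query tokens are expanded over the batch and cross-attend to the vision
# features, producing one output vector per query token.
#
#   query_tokens   : (batch_size, num_query_tokens, qformer_hidden_size)
#   image_features : (batch_size, num_vision_tokens, encoder_hidden_size)
#   qformer output : (batch_size, num_query_tokens, qformer_hidden_size)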


def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
    return hf_config.num_query_tokens


def get_max_blip2_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(Blip2Config)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, Blip2VisionConfig):
        return get_max_blip_image_tokens(vision_config)

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def dummy_seq_data_for_blip2(
    hf_config: Blip2Config,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = get_blip2_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    return SequenceData.from_token_counts(
        (image_token_id, image_feature_size * num_images),
        (0, seq_len - image_feature_size * num_images),
    )
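
# For example, with the usual `num_query_tokens = 32`, one image and
# `seq_len = 2048`, the dummy sequence is 32 copies of `image_token_id`
# followed by 2016 zero tokens (values here are illustrative; the actual
# number of query tokens is read from the HF config).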


def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(Blip2Config)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    seq_data = dummy_seq_data_for_blip2(
        hf_config,
        seq_len,
        num_images,
        image_token_id=BLIP2_IMAGE_TOKEN_ID,
    )

    if isinstance(vision_config, Blip2VisionConfig):
        mm_data = dummy_image_for_blip(vision_config, num_images)

        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    hf_config = ctx.get_hf_config(Blip2Config)
    image_feature_size = get_blip2_image_feature_size(hf_config)

    # The original model places image tokens at the front
    # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
    new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
    new_token_ids += llm_inputs["prompt_token_ids"]

    new_prompt = llm_inputs.get("prompt")
    if new_prompt is not None:
        new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt

    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)
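
# Concretely, for a prompt whose token ids start with `[2, 45641, 35, ...]`
# and a model with 32 query tokens, the processor above produces (values taken
# from the example in `Blip2ForConditionalGeneration.forward` below):
#
#   new_token_ids = [50265] * 32 + [2, 45641, 35, ...]
#   new_prompt    = "<image>" * 32 + original prompt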


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: Blip2Config,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_model = BlipVisionModel(config.vision_config)

        self.query_tokens = nn.Parameter(
            torch.zeros(1, config.num_query_tokens,
                        config.qformer_config.hidden_size))

        self.qformer = Blip2QFormerModel(config.qformer_config,
                                         cache_config=cache_config,
                                         quant_config=quant_config)

        self.language_projection = nn.Linear(
            config.qformer_config.hidden_size,
            config.text_config.hidden_size,
            bias=True,
        )

        self.language_model = init_aphrodite_registered_model(
            config.text_config, cache_config, quant_config)

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Blip2ImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Remove the N dimension until multiple images are supported.
            pixel_values = pixel_values.squeeze(1)

            return Blip2ImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            # Remove the N dimension until multiple images are supported.
            image_embeds = image_embeds.squeeze(1)

            return Blip2ImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(self, vision_model: BlipVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_model(pixel_values)

        return image_features

    def _process_image_pixels(self,
                              inputs: Blip2ImagePixelInputs) -> torch.Tensor:
        assert self.vision_model is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_model, pixel_values)

    def _process_image_input(self,
                             image_input: Blip2ImageInputs) -> torch.Tensor:
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None
        image_features = self._process_image_pixels(image_input)

        query_tokens = self.query_tokens.expand(image_features.shape[0], -1,
                                                -1)
        query_output = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_features,
        )

        return self.language_projection(query_output)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for BLIP-2.

        One key thing to understand is that `input_ids` already accounts for
        the positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"Question: What's the content of the image? Answer:"`.

        The tokenizer outputs:
        `[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]`.

        To reserve space in the KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        dummy tokens (denoted as `50265`), resulting in:
        `[50265, ..., 50265, 2, 45641, 35, ..., 31652, 35]`.

        We insert 32 tokens since this corresponds to the number of query
        embeddings outputted by the Q-Former and inputted to the language
        model.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image.

        See also:
            :class:`Blip2ImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                BLIP2_IMAGE_TOKEN_ID)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states
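
    # A sketch of the merge step above, assuming `merge_multimodal_embeddings`
    # scatters the projected Q-Former outputs into the placeholder positions:
    #
    #   mask = (input_ids == BLIP2_IMAGE_TOKEN_ID)       # 32 positions/image
    #   inputs_embeds[mask] = vision_embeddings.flatten(0, 1)
    #
    # so the language model consumes a single embedding sequence in which the
    # placeholder rows have been replaced by image embeddings.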

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # prepare weight iterators for components
        weights_group = group_weights_with_prefix(weights)

        # load vision encoder
        self.vision_model.load_weights(weights_group["vision_model"])

        # load query tokens
        for name, loaded_weight in weights_group["query_tokens"]:
            assert name == ""
            param = self.query_tokens
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load qformer
        qformer_params_dict = dict(self.qformer.named_parameters())
        for name, loaded_weight in weights_group["qformer"]:
            param = qformer_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load mlp projector
        mlp_params_dict = dict(self.language_projection.named_parameters())
        for name, loaded_weight in weights_group["language_projection"]:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        self.language_model.load_weights(weights_group["language_model"])
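

# Note on checkpoints: `load_weights` above expects HF-style parameter names
# grouped by these top-level prefixes (as split by `group_weights_with_prefix`):
#
#   vision_model.*         -> BlipVisionModel
#   query_tokens           -> the learned query token parameter
#   qformer.*              -> Blip2QFormerModel
#   language_projection.*  -> the Q-Former-to-LLM projection
#   language_model.*       -> the decoder backbone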