blip2.py

from array import array
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
import torch.nn as nn
from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
                          apply_chunking_to_forward)

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors, SequenceData
from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.opt import OPTModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.quantization import QuantizationConfig

from .blip import (BlipVisionModel, dummy_image_for_blip,
                   get_max_blip_image_tokens)
from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings

_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}

# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo.
BLIP2_IMAGE_TOKEN = "<image>"
BLIP2_IMAGE_TOKEN_ID = 50265
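
# These placeholders are consumed in two places further down: the input
# processor prepends `BLIP2_IMAGE_TOKEN_ID` once per query token so that the
# scheduler reserves KV-cache space, and `merge_multimodal_embeddings` later
# overwrites exactly those positions with the projected Q-Former outputs.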


class Blip2ImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


class Blip2ImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """


Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs]
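
# Illustrative shapes for the two accepted input formats (the concrete sizes
# are an assumption based on the Salesforce/blip2-opt-2.7b configuration;
# other checkpoints may differ):
#
#     pixel_values: (batch_size * num_images, 3, 224, 224)
#     image_embeds: (batch_size * num_images, 32, 2560)
#
# where 32 is `num_query_tokens` and 2560 is the hidden size of the OPT
# language model backbone. Precomputed `image_embeds` bypass the vision
# tower and the Q-Former entirely.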


class Blip2QFormerMultiHeadAttention(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        is_cross_attention: bool = False,
    ) -> None:
        super().__init__()

        self.config = config

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of "
                f"the number of attention heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = (config.hidden_size //
                                    config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.scaling = self.attention_head_size**-0.5

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            kv_hidden_size = config.encoder_hidden_size
        else:
            kv_hidden_size = config.hidden_size
        self.key = nn.Linear(kv_hidden_size, self.all_head_size)
        self.value = nn.Linear(kv_hidden_size, self.all_head_size)

        self.position_embedding_type = getattr(config,
                                               "position_embedding_type",
                                               "absolute")
        if self.position_embedding_type != "absolute":
            raise NotImplementedError("Unsupported position_embedding_type: "
                                      f"{self.position_embedding_type}")

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        x = x.view(*x.size()[:-1], self.num_attention_heads,
                   self.attention_head_size)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
    ):
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(
                self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(
                self.value(encoder_hidden_states))
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer)

        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))
        attention_probs = torch.softmax(attention_scores * self.scaling,
                                        dim=-1)

        # This is actually dropping out entire tokens to attend to, which
        # might seem a bit unusual, but is taken from the original
        # Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(*context_layer.size()[:-2],
                                           self.all_head_size)

        return context_layer
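
# Shape walk-through for the attention module above (illustrative; the sizes
# are assumptions based on the default Q-Former configuration with
# hidden_size 768, 12 heads, and ViT features of size 1408 over 257
# positions):
#
#     self-attention:  hidden_states (B, 32, 768) -> context (B, 32, 768)
#     cross-attention: queries (B, 32, 768) attend over
#                      encoder_hidden_states (B, 257, 1408); the K/V linear
#                      layers project 1408 -> 768, giving (B, 32, 768).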


class Blip2QFormerSelfOutput(nn.Module):

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class Blip2QFormerAttention(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        is_cross_attention: bool = False,
    ) -> None:
        super().__init__()

        self.attention = Blip2QFormerMultiHeadAttention(
            config,
            quant_config=quant_config,
            cache_config=cache_config,
            is_cross_attention=is_cross_attention,
        )

        self.output = Blip2QFormerSelfOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        self_output = self.attention(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
        )
        attention_output = self.output(self_output, hidden_states)

        return attention_output


class Blip2QFormerIntermediate(nn.Module):

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = get_act_fn(config.hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class Blip2QFormerOutput(nn.Module):

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()

        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class Blip2QFormerLayer(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        layer_idx: int,
    ) -> None:
        super().__init__()

        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = Blip2QFormerAttention(config,
                                               quant_config=quant_config,
                                               cache_config=cache_config)

        self.layer_idx = layer_idx

        if layer_idx % config.cross_attention_frequency == 0:
            self.crossattention = Blip2QFormerAttention(
                config,
                quant_config=quant_config,
                cache_config=cache_config,
                is_cross_attention=True)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False

        self.intermediate_query = Blip2QFormerIntermediate(config)
        self.output_query = Blip2QFormerOutput(config)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        query_length: int,
    ):
        attention_output = self.attention(hidden_states)

        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                query_attention_output = self.crossattention(
                    query_attention_output,
                    encoder_hidden_states=encoder_hidden_states,
                )

            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )

            if attention_output.shape[1] > query_length:
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text],
                                         dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )

        return layer_output

    def feed_forward_chunk(self,
                           attention_output: torch.Tensor) -> torch.Tensor:
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(
            self, attention_output: torch.Tensor) -> torch.Tensor:
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output
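
# Note on the layer above: when called from `Blip2QFormerEncoder`,
# `query_length` always equals the sequence length of `hidden_states`
# (only learned query tokens are fed in), so only the query branch runs and
# the text branch (`feed_forward_chunk`) is effectively unreachable, kept for
# parity with the HuggingFace implementation this file is adapted from.
# `chunk_size_feed_forward` defaults to 0, in which case
# `apply_chunking_to_forward` simply calls the feed-forward function directly.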


class Blip2QFormerEncoder(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
    ) -> None:
        super().__init__()

        self.config = config

        self.layer = nn.ModuleList([
            Blip2QFormerLayer(config,
                              quant_config=quant_config,
                              cache_config=cache_config,
                              layer_idx=layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        query_length: int,
    ) -> torch.Tensor:
        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]

            hidden_states = layer_module(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                query_length=query_length,
            )

        return hidden_states


# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1025
class Blip2QFormerModel(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
    ) -> None:
        super().__init__()

        self.config = config

        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.encoder = Blip2QFormerEncoder(config,
                                           quant_config=quant_config,
                                           cache_config=cache_config)

    def forward(
        self,
        query_embeds: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
    ) -> torch.Tensor:
        query_length = query_embeds.shape[1]

        embedding_output = self.layernorm(query_embeds)
        embedding_output = self.dropout(embedding_output)

        sequence_output = self.encoder(
            embedding_output,
            encoder_hidden_states=encoder_hidden_states,
            query_length=query_length,
        )

        return sequence_output
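
# Call convention for the Q-Former (illustrative sketch; shapes assume the
# default configuration noted above):
#
#     query_output = qformer(
#         query_embeds=query_tokens,              # (B, 32, 768)
#         encoder_hidden_states=image_features,   # (B, 257, 1408)
#     )                                           # -> (B, 32, 768)
#
# There is no attention mask or KV cache here: the Q-Former runs as a plain
# encoder over a fixed number of query tokens per image.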


def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
    return hf_config.num_query_tokens


def get_max_blip2_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(Blip2Config)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, Blip2VisionConfig):
        return get_max_blip_image_tokens(vision_config)

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def dummy_seq_data_for_blip2(
    hf_config: Blip2Config,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = get_blip2_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                      [image_token_id]) * image_feature_size * num_images
    token_ids += array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                       [0]) * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)
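
# For memory profiling, the dummy sequence above is laid out as
# (illustrative, assuming 32 query tokens per image):
#
#     [image_token_id] * 32 * num_images + [0] * (seq_len - 32 * num_images)
#
# i.e. placeholder ids first, padded with token id 0 up to `seq_len`.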


def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(Blip2Config)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    seq_data = dummy_seq_data_for_blip2(
        hf_config,
        seq_len,
        num_images,
        image_token_id=BLIP2_IMAGE_TOKEN_ID,
    )

    if isinstance(vision_config, Blip2VisionConfig):
        mm_data = dummy_image_for_blip(vision_config, num_images)

        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    hf_config = ctx.get_hf_config(Blip2Config)
    image_feature_size = get_blip2_image_feature_size(hf_config)

    # The original model places image tokens at the front
    # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
    new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
    new_token_ids += llm_inputs["prompt_token_ids"]

    new_prompt = llm_inputs.get("prompt")
    if new_prompt is not None:
        new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt

    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)
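
# Example of the transformation performed above (illustrative; the prompt and
# its token ids follow the example in the `forward` docstring below, and 32
# is the number of query tokens):
#
#     before: prompt_token_ids = [2, 45641, 35, ...]
#             prompt           = "Question: ... Answer:"
#     after:  prompt_token_ids = [50265] * 32 + [2, 45641, 35, ...]
#             prompt           = "<image>" * 32 + "Question: ... Answer:"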


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: Blip2Config,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initialize this for supporting embeddings.
        self.vision_model = BlipVisionModel(config.vision_config)

        self.query_tokens = nn.Parameter(
            torch.zeros(1, config.num_query_tokens,
                        config.qformer_config.hidden_size))

        self.qformer = Blip2QFormerModel(config.qformer_config,
                                         cache_config=cache_config,
                                         quant_config=quant_config)

        self.language_projection = nn.Linear(
            config.qformer_config.hidden_size,
            config.text_config.hidden_size,
            bias=True,
        )

        self.quant_config = quant_config

        self.language_model = OPTModel(config.text_config, cache_config,
                                       quant_config)

        self.unpadded_vocab_size = config.text_config.vocab_size
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size)
        self.sampler = Sampler()

    def get_lm_head(self):
        return self.language_model.decoder.embed_tokens

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Blip2ImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Remove the N dimension until multiple images are supported.
            pixel_values = pixel_values.squeeze(1)

            return Blip2ImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            # Remove the N dimension until multiple images are supported.
            image_embeds = image_embeds.squeeze(1)

            return Blip2ImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(self, vision_model: BlipVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_model(pixel_values)

        return image_features

    def _process_image_pixels(self,
                              inputs: Blip2ImagePixelInputs) -> torch.Tensor:
        assert self.vision_model is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_model, pixel_values)

    def _process_image_input(self,
                             image_input: Blip2ImageInputs) -> torch.Tensor:
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None
        image_features = self._process_image_pixels(image_input)

        query_tokens = self.query_tokens.expand(image_features.shape[0], -1,
                                                -1)
        query_output = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_features,
        )

        return self.language_projection(query_output)
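
    # End-to-end shape sketch for `_process_image_input` (illustrative; the
    # concrete sizes are assumptions matching the blip2-opt-2.7b checkpoint):
    #
    #     pixel_values (B, 3, 224, 224)
    #       -> vision_model          -> image_features (B, 257, 1408)
    #       -> qformer w/ 32 queries -> query_output   (B, 32, 768)
    #       -> language_projection   -> embeddings     (B, 32, 2560)
    #
    # The result replaces the 32 placeholder positions in the OPT input.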

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run the forward pass for BLIP-2.

        One key thing to understand is that the `input_ids` already account
        for the positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"Question: What's the content of the image? Answer:"`.

        The tokenizer outputs:
        `[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]`.

        To reserve space in the KV cache, we have to insert placeholder
        tokens before they are inputted to the model, so the input processor
        prepends dummy tokens (denoted as `50265`), resulting in:
        `[50265, ..., 50265, 2, 45641, 35, ..., 31652, 35]`.

        We insert 32 tokens since that is the number of query embeddings
        outputted by the Q-Former and inputted to the language model.

        This way, the `positions` and `attn_metadata` are consistent with
        the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image.

        See also:
            :class:`Blip2ImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                BLIP2_IMAGE_TOKEN_ID)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            inputs_embeds=inputs_embeds)

        return hidden_states
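
    # A sketch of the merge step in `forward` above (illustrative): with
    # `input_ids = [50265] * 32 + text_ids`, `merge_multimodal_embeddings`
    # overwrites the embeddings at the 32 placeholder positions with
    # `vision_embeddings` and leaves the text embeddings untouched, after
    # which only `inputs_embeds` (not `input_ids`) is passed to the OPT
    # backbone.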

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.get_lm_head(), hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "lm_head.weight" in name:
                continue
            if "rotary_emb.inv_freq" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_model is not None:
                    # BlipVisionModel does not need sharding
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    use_default_weight_loading = True

            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
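
    # Weight-name handling in `load_weights`, by example (illustrative; the
    # checkpoint key below is an assumption about the HF checkpoint layout):
    #
    #     "language_model.model.decoder.layers.0.self_attn.q_proj.weight"
    #       -> renamed via _KEYS_TO_MODIFY_MAPPING to
    #     "language_model.decoder.layers.0.self_attn.q_proj.weight"
    #       -> loaded into the fused "...self_attn.qkv_proj.weight" parameter
    #          with shard_id "q" via stacked_params_mapping.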