blip2.py
from array import array
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
import torch.nn as nn
from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
                          apply_chunking_to_forward)

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                       SequenceData)
from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.opt import OPTModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.quantization import QuantizationConfig

from .blip import (BlipVisionModel, dummy_image_for_blip,
                   get_max_blip_image_tokens)
from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings

_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}

# We use these internally as placeholders since there is no image token
# defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN = "<image>"
BLIP2_IMAGE_TOKEN_ID = 50265
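# NOTE: 50265 appears to lie just past the OPT tokenizer's vocabulary while
# still falling inside the padded OPT embedding table, so it maps to an
# embedding row that no real text token uses. (Assumption based on the OPT
# checkpoints that the released BLIP-2 models pair with.)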


class Blip2ImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size, num_channels, height, width)`"""


class Blip2ImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """


Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs]


class Blip2QFormerMultiHeadAttention(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        is_cross_attention: bool = False,
    ) -> None:
        super().__init__()

        self.config = config

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of "
                f"the number of attention heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = (config.hidden_size //
                                    config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.scaling = self.attention_head_size**-0.5

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            kv_hidden_size = config.encoder_hidden_size
        else:
            kv_hidden_size = config.hidden_size
        self.key = nn.Linear(kv_hidden_size, self.all_head_size)
        self.value = nn.Linear(kv_hidden_size, self.all_head_size)

        self.position_embedding_type = getattr(config,
                                               "position_embedding_type",
                                               "absolute")
        if self.position_embedding_type != "absolute":
            raise NotImplementedError("Unsupported position_embedding_type: "
                                      f"{self.position_embedding_type}")

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        x = x.view(*x.size()[:-1], self.num_attention_heads,
                   self.attention_head_size)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
    ):
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(
                self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(
                self.value(encoder_hidden_states))
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer)

        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))
        attention_probs = torch.softmax(attention_scores * self.scaling,
                                        dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(*context_layer.size()[:-2],
                                           self.all_head_size)

        return context_layer


class Blip2QFormerSelfOutput(nn.Module):

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class Blip2QFormerAttention(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        is_cross_attention: bool = False,
    ) -> None:
        super().__init__()

        self.attention = Blip2QFormerMultiHeadAttention(
            config,
            quant_config=quant_config,
            cache_config=cache_config,
            is_cross_attention=is_cross_attention,
        )

        self.output = Blip2QFormerSelfOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        self_output = self.attention(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
        )
        attention_output = self.output(self_output, hidden_states)

        return attention_output


class Blip2QFormerIntermediate(nn.Module):

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = get_act_fn(config.hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class Blip2QFormerOutput(nn.Module):

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()

        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class Blip2QFormerLayer(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        layer_idx: int,
    ) -> None:
        super().__init__()

        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = Blip2QFormerAttention(config,
                                               quant_config=quant_config,
                                               cache_config=cache_config)

        self.layer_idx = layer_idx

        if layer_idx % config.cross_attention_frequency == 0:
            self.crossattention = Blip2QFormerAttention(
                config,
                quant_config=quant_config,
                cache_config=cache_config,
                is_cross_attention=True)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False

        self.intermediate_query = Blip2QFormerIntermediate(config)
        self.output_query = Blip2QFormerOutput(config)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        query_length: int,
    ):
        attention_output = self.attention(hidden_states)

        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                query_attention_output = self.crossattention(
                    query_attention_output,
                    encoder_hidden_states=encoder_hidden_states,
                )

            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )

            if attention_output.shape[1] > query_length:
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text],
                                         dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )

        return layer_output
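    # NOTE: In this integration the Q-Former only ever receives the learned
    # query embeddings (see Blip2QFormerModel.forward below), so `query_length`
    # always equals the sequence length and the text-only branches above
    # (which route through `feed_forward_chunk`) are not exercised.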

    def feed_forward_chunk(self,
                           attention_output: torch.Tensor) -> torch.Tensor:
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(
            self, attention_output: torch.Tensor) -> torch.Tensor:
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output


class Blip2QFormerEncoder(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
    ) -> None:
        super().__init__()

        self.config = config

        self.layer = nn.ModuleList([
            Blip2QFormerLayer(config,
                              quant_config=quant_config,
                              cache_config=cache_config,
                              layer_idx=layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        query_length: int,
    ) -> torch.Tensor:
        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]

            hidden_states = layer_module(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                query_length=query_length,
            )

        return hidden_states


# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1025
class Blip2QFormerModel(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
    ) -> None:
        super().__init__()

        self.config = config

        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.encoder = Blip2QFormerEncoder(config,
                                           quant_config=quant_config,
                                           cache_config=cache_config)

    def forward(
        self,
        query_embeds: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
    ) -> torch.Tensor:
        query_length = query_embeds.shape[1]

        embedding_output = self.layernorm(query_embeds)
        embedding_output = self.dropout(embedding_output)

        sequence_output = self.encoder(
            embedding_output,
            encoder_hidden_states=encoder_hidden_states,
            query_length=query_length,
        )

        return sequence_output
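

# Shape sketch for Blip2QFormerModel.forward (illustrative; assumes the
# released BLIP-2 checkpoints, where config.num_query_tokens == 32):
#   query_embeds:          (batch_size, 32, qformer_hidden_size)
#   encoder_hidden_states: (batch_size, num_patches, vision_hidden_size)
#   returned sequence:     (batch_size, 32, qformer_hidden_size)
# i.e. the Q-Former compresses a variable number of ViT patch features into a
# fixed number of query embeddings via the cross-attention layers above.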


def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
    return hf_config.num_query_tokens


def get_max_blip2_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(Blip2Config)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, Blip2VisionConfig):
        return get_max_blip_image_tokens(vision_config)

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def dummy_seq_data_for_blip2(
    hf_config: Blip2Config,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = get_blip2_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                      [image_token_id]) * image_feature_size * num_images
    token_ids += array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                       [0]) * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)
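
# NOTE: The dummy sequence above (and the dummy image built in
# dummy_data_for_blip2 below) is only used when the engine profiles peak
# memory at startup; it is sized to the worst case, with every image expanded
# to its full placeholder length. (Assumption: Aphrodite follows the same
# profiling scheme as upstream vLLM here.)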


def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(Blip2Config)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    seq_data = dummy_seq_data_for_blip2(
        hf_config,
        seq_len,
        num_images,
        image_token_id=BLIP2_IMAGE_TOKEN_ID,
    )

    if isinstance(vision_config, Blip2VisionConfig):
        mm_data = dummy_image_for_blip(vision_config, num_images)

        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    hf_config = ctx.get_hf_config(Blip2Config)
    image_feature_size = get_blip2_image_feature_size(hf_config)

    # The original model places image tokens at the front
    # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
    new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
    new_token_ids += llm_inputs["prompt_token_ids"]

    new_prompt = llm_inputs.get("prompt")
    if new_prompt is not None:
        new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt

    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)
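
# Illustrative example of the expansion performed above, assuming
# num_query_tokens == 32 as in the released BLIP-2 checkpoints (token ids are
# the ones quoted in Blip2ForConditionalGeneration.forward's docstring below):
#   prompt:            "Question: What's the content of the image? Answer:"
#   prompt_token_ids:  [2, 45641, 35, ..., 31652, 35]
#   after processing:  [50265] * 32 + [2, 45641, 35, ..., 31652, 35]
#   prompt string:     "<image>" * 32 + the original prompt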


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: Blip2Config,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initialize this to support embedding inputs.
        self.vision_model = BlipVisionModel(config.vision_config)

        self.query_tokens = nn.Parameter(
            torch.zeros(1, config.num_query_tokens,
                        config.qformer_config.hidden_size))

        self.qformer = Blip2QFormerModel(config.qformer_config,
                                         cache_config=cache_config,
                                         quant_config=quant_config)

        self.language_projection = nn.Linear(
            config.qformer_config.hidden_size,
            config.text_config.hidden_size,
            bias=True,
        )

        self.quant_config = quant_config

        self.language_model = OPTModel(config.text_config, cache_config,
                                       quant_config)

        self.unpadded_vocab_size = config.text_config.vocab_size
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size)
        self.sampler = Sampler()

    def get_lm_head(self):
        return self.language_model.decoder.embed_tokens
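
    # NOTE: The OPT checkpoints tie the output projection to the input
    # embedding table, so the decoder's embed_tokens doubles as the LM head
    # here; load_weights below accordingly skips any standalone
    # "lm_head.weight" entry in the checkpoint.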

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Blip2ImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            return Blip2ImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return Blip2ImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(self, vision_model: BlipVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_model(pixel_values)

        return image_features

    def _process_image_pixels(self,
                              inputs: Blip2ImagePixelInputs) -> torch.Tensor:
        assert self.vision_model is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_model, pixel_values)

    def _process_image_input(self,
                             image_input: Blip2ImageInputs) -> torch.Tensor:
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None
        image_features = self._process_image_pixels(image_input)

        query_tokens = self.query_tokens.expand(image_features.shape[0], -1,
                                                -1)
        query_output = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_features,
        )

        return self.language_projection(query_output)
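
    # Pipeline sketch for _process_image_input (shapes are illustrative,
    # assuming num_query_tokens == 32):
    #   pixel_values (B, 3, H, W)
    #     -> vision_model          -> image_features (B, num_patches, vision_hidden)
    #     -> qformer(query_tokens) -> query_output   (B, 32, qformer_hidden)
    #     -> language_projection   -> embeddings     (B, 32, text_hidden)
    # These embeddings then replace the 32 placeholder tokens per image in
    # forward() below.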

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for BLIP-2.

        One key thing to understand is that `input_ids` already accounts for
        the positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"Question: What's the content of the image? Answer:"`.

        Tokenizer outputs:
        `[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]`.

        To reserve space in the KV cache, we have to insert placeholder tokens
        before they are passed to the model, so the input processor prepends
        dummy tokens (denoted as `50265`), resulting in:
        `[50265, ..., 50265, 2, 45641, 35, ..., 31652, 35]`.

        We insert 32 tokens since that is the number of query embeddings
        produced by the Q-Former and fed to the language model.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image.

        See also:
            :class:`Blip2ImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                BLIP2_IMAGE_TOKEN_ID)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.get_lm_head(), hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Only doing this for the language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
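        # NOTE: Separate q/k/v checkpoint shards are folded into the fused
        # qkv_proj of the Aphrodite OPT implementation. The gate_up entries
        # never match OPT weight names (OPT uses fc1/fc2); they appear to be
        # kept only for parity with mappings used by other decoders.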
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "lm_head.weight" in name:
                continue
            if "rotary_emb.inv_freq" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_model is not None:
                    # We only do sharding for the language model and
                    # not the vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    use_default_weight_loading = True
            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
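

# Example usage (a minimal sketch, not part of the model definition): this
# assumes Aphrodite exposes the same offline-inference entrypoint and
# multi-modal prompt format as upstream vLLM, and that the input processor
# registered above prepends the 32 image placeholders automatically.
#
#   from PIL import Image
#   from aphrodite import LLM, SamplingParams
#
#   llm = LLM(model="Salesforce/blip2-opt-2.7b")
#   image = Image.open("example.jpg").convert("RGB")
#   outputs = llm.generate(
#       {
#           "prompt": "Question: What's the content of the image? Answer:",
#           "multi_modal_data": {"image": image},
#       },
#       SamplingParams(max_tokens=64),
#   )
#   print(outputs[0].outputs[0].text)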