blip2.py
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
import torch.nn as nn
from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
                          apply_chunking_to_forward)

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                       SequenceData)
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.opt import OPTModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.quantization import QuantizationConfig

from .blip import (BlipVisionModel, dummy_image_for_blip,
                   get_max_blip_image_tokens)
from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings
_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}

# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN = "<image>"
BLIP2_IMAGE_TOKEN_ID = 50265
class Blip2ImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size, num_channels, height, width)`"""


class Blip2ImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """


Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs]
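
# The Q-Former modules below are adapted from the HuggingFace implementation
# (see the link above Blip2QFormerModel). They run once per image with plain
# eager attention over a short, fixed-length sequence of learned query tokens,
# so they do not use Aphrodite's paged attention; the `quant_config` and
# `cache_config` arguments are accepted but not currently consumed, and the
# projections remain plain `nn.Linear` layers.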
class Blip2QFormerMultiHeadAttention(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        is_cross_attention: bool = False,
    ) -> None:
        super().__init__()

        self.config = config

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of "
                f"the number of attention heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = (config.hidden_size //
                                    config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.scaling = self.attention_head_size**-0.5

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            kv_hidden_size = config.encoder_hidden_size
        else:
            kv_hidden_size = config.hidden_size
        self.key = nn.Linear(kv_hidden_size, self.all_head_size)
        self.value = nn.Linear(kv_hidden_size, self.all_head_size)

        self.position_embedding_type = getattr(config,
                                               "position_embedding_type",
                                               "absolute")
        if self.position_embedding_type != "absolute":
            raise NotImplementedError("Unsupported position_embedding_type: "
                                      f"{self.position_embedding_type}")

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        x = x.view(*x.size()[:-1], self.num_attention_heads,
                   self.attention_head_size)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
    ):
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(
                self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(
                self.value(encoder_hidden_states))
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer)

        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))
        attention_probs = torch.softmax(attention_scores * self.scaling,
                                        dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(*context_layer.size()[:-2],
                                           self.all_head_size)

        return context_layer
class Blip2QFormerSelfOutput(nn.Module):

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
class Blip2QFormerAttention(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        is_cross_attention: bool = False,
    ) -> None:
        super().__init__()

        self.attention = Blip2QFormerMultiHeadAttention(
            config,
            quant_config=quant_config,
            cache_config=cache_config,
            is_cross_attention=is_cross_attention,
        )

        self.output = Blip2QFormerSelfOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        self_output = self.attention(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
        )
        attention_output = self.output(self_output, hidden_states)

        return attention_output
class Blip2QFormerIntermediate(nn.Module):

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = get_act_fn(config.hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
class Blip2QFormerOutput(nn.Module):

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()

        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
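
# Each Q-Former layer always self-attends over the query tokens; cross-attention
# to the image features is only instantiated every `cross_attention_frequency`
# layers (layer indices 0, cross_attention_frequency, ...), following the BLIP-2
# Q-Former configuration.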
class Blip2QFormerLayer(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        layer_idx: int,
    ) -> None:
        super().__init__()

        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = Blip2QFormerAttention(config,
                                               quant_config=quant_config,
                                               cache_config=cache_config)

        self.layer_idx = layer_idx

        if layer_idx % config.cross_attention_frequency == 0:
            self.crossattention = Blip2QFormerAttention(
                config,
                quant_config=quant_config,
                cache_config=cache_config,
                is_cross_attention=True)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False

        self.intermediate_query = Blip2QFormerIntermediate(config)
        self.output_query = Blip2QFormerOutput(config)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        query_length: int,
    ):
        attention_output = self.attention(hidden_states)

        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                query_attention_output = self.crossattention(
                    query_attention_output,
                    encoder_hidden_states=encoder_hidden_states,
                )

            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )

            if attention_output.shape[1] > query_length:
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text],
                                         dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )

        return layer_output
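
    # The two helpers below are invoked through HF's `apply_chunking_to_forward`,
    # which splits the input along `seq_len_dim` into chunks of
    # `chunk_size_feed_forward` (0 disables chunking) to bound peak memory in the
    # MLP. In this integration the Q-Former only ever receives query tokens
    # (`query_length` equals the sequence length), so only the `_query` variant
    # is exercised in practice.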
    def feed_forward_chunk(self,
                           attention_output: torch.Tensor) -> torch.Tensor:
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(
            self, attention_output: torch.Tensor) -> torch.Tensor:
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output
class Blip2QFormerEncoder(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
    ) -> None:
        super().__init__()

        self.config = config

        self.layer = nn.ModuleList([
            Blip2QFormerLayer(config,
                              quant_config=quant_config,
                              cache_config=cache_config,
                              layer_idx=layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        query_length: int,
    ) -> torch.Tensor:
        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]

            hidden_states = layer_module(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                query_length=query_length,
            )

        return hidden_states
# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1025
class Blip2QFormerModel(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
    ) -> None:
        super().__init__()

        self.config = config

        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.encoder = Blip2QFormerEncoder(config,
                                           quant_config=quant_config,
                                           cache_config=cache_config)

    def forward(
        self,
        query_embeds: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
    ) -> torch.Tensor:
        query_length = query_embeds.shape[1]

        embedding_output = self.layernorm(query_embeds)
        embedding_output = self.dropout(embedding_output)

        sequence_output = self.encoder(
            embedding_output,
            encoder_hidden_states=encoder_hidden_states,
            query_length=query_length,
        )

        return sequence_output
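
# The Q-Former consumes the vision features purely as cross-attention
# keys/values and returns one hidden state per learned query token, so its
# output length is fixed at `num_query_tokens` regardless of image resolution.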
def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
    return hf_config.num_query_tokens


def get_max_blip2_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(Blip2Config)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, Blip2VisionConfig):
        return get_max_blip_image_tokens(vision_config)

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
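
# Note: the per-image placeholder count used by the dummy data and input
# processor below is `num_query_tokens` (32 for the released BLIP-2
# checkpoints), whereas `get_max_blip2_image_tokens` above reports a budget
# derived from the BLIP vision tower.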
def dummy_seq_data_for_blip2(
    hf_config: Blip2Config,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = get_blip2_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    token_ids = [image_token_id] * image_feature_size * num_images
    token_ids += [0] * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)
def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(Blip2Config)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    seq_data = dummy_seq_data_for_blip2(
        hf_config,
        seq_len,
        num_images,
        image_token_id=BLIP2_IMAGE_TOKEN_ID,
    )

    if isinstance(vision_config, Blip2VisionConfig):
        mm_data = dummy_image_for_blip(vision_config, num_images)

        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    hf_config = ctx.get_hf_config(Blip2Config)
    image_feature_size = get_blip2_image_feature_size(hf_config)

    # The original model places image tokens at the front
    # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
    new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
    new_token_ids += llm_inputs["prompt_token_ids"]

    new_prompt = llm_inputs.get("prompt")
    if new_prompt is not None:
        new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt

    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)
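
# Illustrative example of the processing above (token values follow the
# docstring in `Blip2ForConditionalGeneration.forward`): with the default 32
# query tokens,
#     prompt_token_ids = [2, 45641, 35, ...]
# becomes
#     prompt_token_ids = [50265] * 32 + [2, 45641, 35, ...]
# so the language model reserves one KV-cache slot per image embedding.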
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: Blip2Config,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initialize this for supporting embeddings.
        self.vision_model = BlipVisionModel(config.vision_config)

        self.query_tokens = nn.Parameter(
            torch.zeros(1, config.num_query_tokens,
                        config.qformer_config.hidden_size))

        self.qformer = Blip2QFormerModel(config.qformer_config,
                                         cache_config=cache_config,
                                         quant_config=quant_config)

        self.language_projection = nn.Linear(
            config.qformer_config.hidden_size,
            config.text_config.hidden_size,
            bias=True,
        )

        self.quant_config = quant_config

        self.language_model = OPTModel(config.text_config, cache_config,
                                       quant_config)

        self.unpadded_vocab_size = config.text_config.vocab_size
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size)
        self.sampler = Sampler()
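
    # The LM head is tied to the input embedding matrix: logits are computed
    # against `language_model.decoder.embed_tokens`, and the checkpoint's
    # `lm_head.weight` is skipped in `load_weights` below.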
    def get_lm_head(self):
        return self.language_model.decoder.embed_tokens

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Blip2ImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")
            return Blip2ImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")
            return Blip2ImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")
    def _image_pixels_to_features(self, vision_model: BlipVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_model(pixel_values)

        return image_features

    def _process_image_pixels(self,
                              inputs: Blip2ImagePixelInputs) -> torch.Tensor:
        assert self.vision_model is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_model, pixel_values)

    def _process_image_input(self,
                             image_input: Blip2ImageInputs) -> torch.Tensor:

        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None
        image_features = self._process_image_pixels(image_input)

        query_tokens = self.query_tokens.expand(image_features.shape[0], -1,
                                                -1)
        query_output = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_features,
        )

        return self.language_projection(query_output)
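
    # `_process_image_input` returns one embedding of shape
    # (num_query_tokens, text_hidden_size) per image, already projected into
    # the language model's embedding space; `forward` below splices these into
    # the token embeddings at the `BLIP2_IMAGE_TOKEN_ID` placeholder positions.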
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for BLIP-2.

        One key thing to understand is that `input_ids` already accounts for
        the positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"Question: What's the content of the image? Answer:"`.

        Tokenizer outputs:
        `[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]`.

        To reserve space in the KV cache, we have to insert placeholder tokens
        before they are passed to the model, so the input processor prepends
        dummy tokens (denoted as `50265`), resulting in:
        `[50265, ..., 50265, 2, 45641, 35, ..., 31652, 35]`.

        We insert 32 tokens since that corresponds to the number of query
        embeddings output by the Q-Former and fed to the language model.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image.

        See also:
            :class:`Blip2ImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                BLIP2_IMAGE_TOKEN_ID)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            inputs_embeds=inputs_embeds)

        return hidden_states
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.get_lm_head(), hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "lm_head.weight" in name:
                continue
            if "rotary_emb.inv_freq" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_model is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    use_default_weight_loading = True
            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
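
# Illustrative usage sketch (not part of this module). This assumes Aphrodite
# exposes the same multi-modal `LLM.generate` interface as vLLM, where images
# are passed through `multi_modal_data`; adapt to the API of your installed
# version. `example.jpg` is a hypothetical local file.
#
#     from PIL import Image
#     from aphrodite import LLM, SamplingParams
#
#     llm = LLM(model="Salesforce/blip2-opt-2.7b")
#     image = Image.open("example.jpg")
#     outputs = llm.generate(
#         {
#             "prompt": "Question: What is in the image? Answer:",
#             "multi_modal_data": {"image": image},
#         },
#         SamplingParams(max_tokens=64),
#     )
#     print(outputs[0].outputs[0].text)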